1from membase.api.rest_client import RestConnection, RestHelper
2from memcached.helper.data_helper import MemcachedClientHelper
3from remote.remote_util import RemoteMachineShellConnection
4from mc_bin_client import MemcachedClient, MemcachedError
5from membase.api.exception import ServerAlreadyJoinedException
6from membase.helper.rebalance_helper import RebalanceHelper
7import memcacheConstants
8
9import logger
10import testconstants
11import time
12import Queue
13from threading import Thread
14import traceback
15
16
17class ClusterOperationHelper(object):
18    #the first ip is taken as the master ip
19
    # Returns True if the cluster successfully finished the rebalance
21    @staticmethod
22    def add_and_rebalance(servers, wait_for_rebalance=True):
23        log = logger.Logger.get_logger()
24        master = servers[0]
25        all_nodes_added = True
26        rebalanced = True
27        rest = RestConnection(master)
28        if len(servers) > 1:
29            for serverInfo in servers[1:]:
30                log.info('adding node : {0}:{1} to the cluster'.format(
31                        serverInfo.ip, serverInfo.port))
32                otpNode = rest.add_node(master.rest_username, master.rest_password, serverInfo.ip, port=serverInfo.port)
33                if otpNode:
34                    log.info('added node : {0} to the cluster'.format(otpNode.id))
35                else:
36                    all_nodes_added = False
37                    break
38            if all_nodes_added:
39                rest.rebalance(otpNodes=[node.id for node in rest.node_statuses()], ejectedNodes=[])
40                if wait_for_rebalance:
41                    rebalanced &= rest.monitorRebalance()
42                else:
43                    rebalanced = False
44        return all_nodes_added and rebalanced
45
46    @staticmethod
47    def add_all_nodes_or_assert(master, all_servers, rest_settings, test_case):
48        log = logger.Logger.get_logger()
49        otpNodes = []
50        all_nodes_added = True
51        rest = RestConnection(master)
52        for serverInfo in all_servers:
53            if serverInfo.ip != master.ip:
54                log.info('adding node : {0}:{1} to the cluster'.format(
55                        serverInfo.ip, serverInfo.port))
56                otpNode = rest.add_node(rest_settings.rest_username,
57                                        rest_settings.rest_password,
58                                        serverInfo.ip)
59                if otpNode:
60                    log.info('added node : {0} to the cluster'.format(otpNode.id))
61                    otpNodes.append(otpNode)
62                else:
63                    all_nodes_added = False
64        if not all_nodes_added:
65            if test_case:
66                test_case.assertTrue(all_nodes_added,
67                                     msg="unable to add all the nodes to the cluster")
68            else:
69                log.error("unable to add all the nodes to the cluster")
70        return otpNodes
71
72    #wait_if_warmup=True is useful in tearDown method for (auto)failover tests
73    @staticmethod
74    def wait_for_ns_servers_or_assert(servers, testcase, wait_time=360, wait_if_warmup=False):
75        for server in servers:
76            rest = RestConnection(server)
77            log = logger.Logger.get_logger()
78            log.info("waiting for ns_server @ {0}:{1}".format(server.ip, server.port))
79            if RestHelper(rest).is_ns_server_running(wait_time):
80                log.info("ns_server @ {0}:{1} is running".format(server.ip, server.port))
81            elif wait_if_warmup:
82                # wait when warmup completed
83                buckets = rest.get_buckets()
84                for bucket in buckets:
85                    testcase.assertTrue(ClusterOperationHelper._wait_warmup_completed(testcase, \
86                                [server], bucket.name, wait_time), "warmup was not completed!")
87
88            else:
89                testcase.fail("ns_server {0} is not running in {1} sec".format(server.ip, wait_time))
90
91    #returns true if warmup is completed in wait_time sec
92    #otherwise return false
93    @staticmethod
94    def _wait_warmup_completed(self, servers, bucket_name, wait_time=300):
95        warmed_up = False
96        log = logger.Logger.get_logger()
97        for server in servers:
98            mc = None
99            start = time.time()
100            # Try to get the stats for 5 minutes, else hit out.
101            while time.time() - start < wait_time:
102                # Get the wamrup time for each server
103                try:
104                    mc = MemcachedClientHelper.direct_client(server, bucket_name)
105                    stats = mc.stats()
106                    if stats is not None:
107                        warmup_time = int(stats["ep_warmup_time"])
108                        log.info("ep_warmup_time is %s " % warmup_time)
109                        log.info(
110                            "Collected the stats 'ep_warmup_time' %s for server %s:%s" %
111                                (stats["ep_warmup_time"], server.ip, server.port))
112                        break
113                    else:
114                        log.info(" Did not get the stats from the server yet, trying again.....")
115                        time.sleep(2)
116                except Exception as e:
117                    log.error(
118                        "Could not get ep_warmup_time stats from server %s:%s, exception %s" %
119                             (server.ip, server.port, e))
120            else:
121                self.fail(
122                    "Fail! Unable to get the warmup-stats from server %s:%s after trying for %s seconds." % (
123                        server.ip, server.port, wait_time))
124
125            # Waiting for warm-up
126            start = time.time()
127            warmed_up = False
128            while time.time() - start < wait_time and not warmed_up:
129                if mc.stats()["ep_warmup_thread"] == "complete":
130                    log.info("warmup completed, awesome!!! Warmed up. %s items " % (mc.stats()["curr_items_tot"]))
131                    warmed_up = True
132                    continue
133                elif mc.stats()["ep_warmup_thread"] == "running":
134                    log.info(
135                                "still warming up .... curr_items_tot : %s" % (mc.stats()["curr_items_tot"]))
136                else:
137                    fail("Value of ep warmup thread does not exist, exiting from this server")
138                time.sleep(5)
139            mc.close()
140        return warmed_up
141
142    @staticmethod
143    def verify_persistence(servers, test, keys_count=400000, timeout_in_seconds=300):
144        log = logger.Logger.get_logger()
145        master = servers[0]
146        rest = RestConnection(master)
147        log.info("Verifying Persistence")
148        buckets = rest.get_buckets()
149        for bucket in buckets:
150        #Load some data
151            l_threads = MemcachedClientHelper.create_threads([master], bucket.name,
152                                                                     - 1, keys_count, {1024: 0.50, 512: 0.50}, 2, -1,
153                                                                     True, True)
154            [t.start() for t in l_threads]
155            # Do persistence verification
156            ready = ClusterOperationHelper.persistence_verification(servers, bucket.name, timeout_in_seconds)
157            log.info("Persistence Verification returned ? {0}".format(ready))
158            log.info("waiting for persistence threads to finish...")
159            for t in l_threads:
160                t.aborted = True
161            for t in l_threads:
162                t.join()
163            log.info("persistence thread has finished...")
164            test.assertTrue(ready, msg="Cannot verify persistence")
165
166
167    @staticmethod
168    def persistence_verification(servers, bucket, timeout_in_seconds=1260):
169        log = logger.Logger.get_logger()
170        verification_threads = []
171        queue = Queue.Queue()
172        rest = RestConnection(servers[0])
173        nodes = rest.get_nodes()
174        nodes_ip = []
175        for node in nodes:
176            nodes_ip.append(node.ip)
177        for i in range(len(servers)):
178            if servers[i].ip in nodes_ip:
179                log.info("Server {0}:{1} part of cluster".format(
180                        servers[i].ip, servers[i].port))
181                rest = RestConnection(servers[i])
182                t = Thread(target=ClusterOperationHelper.persistence_verification_per_node,
183                           name="verification-thread-{0}".format(servers[i]),
184                           args=(rest, bucket, queue, timeout_in_seconds))
185                verification_threads.append(t)
186        for t in verification_threads:
187            t.start()
188        for t in verification_threads:
189            t.join()
190            log.info("thread {0} finished".format(t.name))
191        while not queue.empty():
192            item = queue.get()
193            if item is False:
194                return False
195        return True
196
197    @staticmethod
198    def persistence_verification_per_node(rest, bucket, queue=None, timeout=1260):
199        log = logger.Logger.get_logger()
200        stat_key = 'ep_flusher_todo'
201        start = time.time()
202        stats = []
203        # Collect stats data points
204        while time.time() - start <= timeout:
205            _new_stats = rest.get_bucket_stats(bucket)
206            if _new_stats and 'ep_flusher_todo' in _new_stats:
207                stats.append(_new_stats[stat_key])
208                time.sleep(0.5)
209            else:
210                log.error("unable to obtain stats for bucket : {0}".format(bucket))
211        value_90th = ClusterOperationHelper.percentile(stats, 90)
212        average = float(sum(stats)) / len(stats)
213        log.info("90th percentile value is {0} and average {1}".format(value_90th, average))
214        if value_90th == 0 and average == 0:
215            queue.put(False)
216            return
217        queue.put(True)
218
219    @staticmethod
220    def percentile(samples, percentile):
221        element_idx = int(len(samples) * (percentile / 100.0))
222        samples.sort()
223        value = samples[element_idx]
224        return value
225
226    @staticmethod
227    def start_cluster(servers):
228        for server in servers:
229            shell = RemoteMachineShellConnection(server)
230            if shell.is_couchbase_installed():
231                shell.start_couchbase()
232            else:
233                shell.start_membase()
234
235    @staticmethod
236    def stop_cluster(servers):
237        for server in servers:
238            shell = RemoteMachineShellConnection(server)
239            if shell.is_couchbase_installed():
240                shell.stop_couchbase()
241            else:
242                shell.stop_membase()
243
244    @staticmethod
245    def cleanup_cluster(servers, wait_for_rebalance=True, master = None):
246        log = logger.Logger.get_logger()
247        if master == None:
248            master = servers[0]
249        rest = RestConnection(master)
250        helper = RestHelper(rest)
251        helper.is_ns_server_running(timeout_in_seconds=testconstants.NS_SERVER_TIMEOUT)
252        nodes = rest.node_statuses()
253        master_id = rest.get_nodes_self().id
254        if len(nodes) > 1:
255            log.info("rebalancing all nodes in order to remove nodes")
256            rest.log_client_error("Starting rebalance from test, ejected nodes %s" % \
257                                                             [node.id for node in nodes if node.id != master_id])
258            removed = helper.remove_nodes(knownNodes=[node.id for node in nodes],
259                                          ejectedNodes=[node.id for node in nodes if node.id != master_id],
260                                          wait_for_rebalance=wait_for_rebalance)
261            success_cleaned = []
262            for removed in [node for node in nodes if (node.id != master_id)]:
263                removed.rest_password = servers[0].rest_password
264                removed.rest_username = servers[0].rest_username
265                try:
266                    rest = RestConnection(removed)
267                except Exception as ex:
268                    log.error("can't create rest connection after rebalance out for ejected nodes,\
269                        will retry after 10 seconds according to MB-8430: {0} ".format(ex))
270                    time.sleep(10)
271                    rest = RestConnection(removed)
272                start = time.time()
273                while time.time() - start < 30:
274                    if len(rest.get_pools_info()["pools"]) == 0:
275                        success_cleaned.append(removed)
276                        break
277                    else:
278                        time.sleep(0.1)
279                if time.time() - start > 10:
280                    log.error("'pools' on node {0}:{1} - {2}".format(
281                           removed.ip, removed.port, rest.get_pools_info()["pools"]))
282            for node in set([node for node in nodes if (node.id != master_id)]) - set(success_cleaned):
283                log.error("node {0}:{1} was not cleaned after removing from cluster".format(
284                           removed.ip, removed.port))
285                try:
286                    rest = RestConnection(node)
287                    rest.force_eject_node()
288                except Exception as ex:
289                    log.error("force_eject_node {0}:{1} failed: {2}".format(removed.ip, removed.port, ex))
290            if len(set([node for node in nodes if (node.id != master_id)])\
291                    - set(success_cleaned)) != 0:
292                raise Exception("not all ejected nodes were cleaned successfully")
293
294            log.info("removed all the nodes from cluster associated with {0} ? {1}".format(servers[0], \
295                    [(node.id, node.port) for node in nodes if (node.id != master_id)]))
296
297    @staticmethod
298    def flushctl_start(servers, username=None, password=None):
299        for server in servers:
300            c = MemcachedClient(server.ip, 11210)
301            if username:
302                c.sasl_auth_plain(username, password)
303            c.start_persistence()
304
305    @staticmethod
306    def flushctl_stop(servers, username=None, password=None):
307        for server in servers:
308            c = MemcachedClient(server.ip, 11210)
309            if username:
310                c.sasl_auth_plain(username, password)
311            c.stop_persistence()
312
313    @staticmethod
314    def flush_os_caches(servers):
315        log = logger.Logger.get_logger()
316        for server in servers:
317            try:
318                shell = RemoteMachineShellConnection(server)
319                shell.flush_os_caches()
320                log.info("Clearing os caches on {0}".format(server))
321            except:
322                pass
323
324    @staticmethod
325    def flushctl_set(master, key, val, bucket='default'):
326        rest = RestConnection(master)
327        servers = rest.get_nodes()
328        for server in servers:
329            _server = {"ip": server.ip, "port": server.port,
330                       "username": master.rest_username,
331                       "password": master.rest_password}
332            ClusterOperationHelper.flushctl_set_per_node(_server, key, val, bucket)
333
334    @staticmethod
335    def flushctl_set_per_node(server, key, val, bucket='default'):
336        log = logger.Logger.get_logger()
337        rest = RestConnection(server)
338        node = rest.get_nodes_self()
339        mc = MemcachedClientHelper.direct_client(server, bucket)
340        log.info("Setting flush param on server {0}, {1} to {2} on {3}".format(server, key, val, bucket))
341        # Workaround for CBQE-249, ideally this should be node.version
342        index_path = node.storage[0].get_index_path()
343        if index_path is '':
344            # Indicates non 2.0 build
345            rv = mc.set_flush_param(key, str(val))
346        else:
347            type = ClusterOperationHelper._get_engine_param_type(key)
348            rv = mc.set_param(key, str(val), type)
349        log.info("Setting flush param on server {0}, {1} to {2}, result: {3}".format(server, key, val, rv))
350        mc.close()
351
352    @staticmethod
353    def _get_engine_param_type(key):
354        tap_params = ['tap_keepalive', 'tap_throttle_queue_cap', 'tap_throttle_threshold']
355        checkpoint_params = ['chk_max_items', 'chk_period', 'inconsistent_slave_chk', 'keep_closed_chks',
356                             'max_checkpoints', 'item_num_based_new_chk']
357        flush_params = ['bg_fetch_delay', 'couch_response_timeout', 'exp_pager_stime', 'flushall_enabled',
358                        'klog_compactor_queue_cap', 'klog_max_log_size', 'klog_max_entry_ratio',
359                        'queue_age_cap', 'max_size', 'max_txn_size', 'mem_high_wat', 'mem_low_wat',
360                        'min_data_age', 'timing_log', 'alog_sleep_time']
361        if key in tap_params:
362            return memcacheConstants.ENGINE_PARAM_TAP
363        if key in checkpoint_params:
364            return memcacheConstants.ENGINE_PARAM_CHECKPOINT
365        if key in flush_params:
366            return memcacheConstants.ENGINE_PARAM_FLUSH
367
368    @staticmethod
369    def set_expiry_pager_sleep_time(master, bucket, value=30):
370        log = logger.Logger.get_logger()
371        rest = RestConnection(master)
372        servers = rest.get_nodes()
373        for server in servers:
374            #this is not bucket specific so no need to pass in the bucketname
375            log.info("connecting to memcached {0}:{1}".format(server.ip, server.memcached))
376            mc = MemcachedClientHelper.direct_client(server, bucket)
377            log.info("Set exp_pager_stime flush param on server {0}:{1}".format(server.ip, server.port))
378            try:
379                mc.set_flush_param("exp_pager_stime", str(value))
380                log.info("Set exp_pager_stime flush param on server {0}:{1}".format(server.ip, server.port))
381            except Exception as ex:
382                traceback.print_exc()
383                log.error("Unable to set exp_pager_stime flush param on memcached {0}:{1}".format(server.ip, server.memcached))
384
385    @staticmethod
386    def get_mb_stats(servers, key):
387        log = logger.Logger.get_logger()
388        for server in servers:
389            c = MemcachedClient(server.ip, 11210)
390            log.info("Get flush param on server {0}, {1}".format(server, key))
391            value = c.stats().get(key, None)
392            log.info("Get flush param on server {0}, {1}".format(server, value))
393            c.close()
394
395    @staticmethod
396    def change_erlang_threads_values(servers, sync_threads=True, num_threads='16:16'):
397        """Change the the type of sync erlang threads and its value
398           sync_threads=True means sync threads +S with default threads number equal 16:16
399           sync_threads=False means async threads: +A 16, for instance
400
401        Default: +S 16:16
402        """
403        log = logger.Logger.get_logger()
404        for server in servers:
405            sh = RemoteMachineShellConnection(server)
406            product = "membase"
407            if sh.is_couchbase_installed():
408                product = "couchbase"
409
410            sync_type = sync_threads and "S" or "A"
411
412            command = "sed -i 's/+[A,S] .*/+%s %s \\\/g' /opt/%s/bin/%s-server" % \
413                 (sync_type, num_threads, product, product)
414            o, r = sh.execute_command(command)
415            sh.log_command_output(o, r)
416            msg = "modified erlang +%s to %s for server %s"
417            log.info(msg % (sync_type, num_threads, server.ip))
418
    @staticmethod
    def set_erlang_schedulers(servers, value="16:16"):
        """
        Set num of erlang schedulers.
        Also erase async option (+A)
        """
        # Restart the whole cluster so the new scheduler setting takes
        # effect: stop everything, patch the start script, start again.
        ClusterOperationHelper.stop_cluster(servers)

        log = logger.Logger.get_logger()
        for server in servers:
            sh = RemoteMachineShellConnection(server)
            product = "membase"
            if sh.is_couchbase_installed():
                product = "couchbase"
            # NOTE(review): the sed pattern "S\+ 128:128" only rewrites a
            # literal "S+ 128:128" occurrence ("S\+" = one-or-more "S") --
            # confirm the start script really contains that exact default
            command = "sed -i 's/S\+ 128:128/S %s/' /opt/%s/bin/%s-server"\
                      % (value, product, product)
            o, r = sh.execute_command(command)
            sh.log_command_output(o, r)
            # NOTE(review): message says "+A" but this modifies the +S
            # (scheduler) setting -- likely a copy/paste from
            # change_erlang_threads_values
            log.info("modified erlang +A to %s for server %s"
                     % (value, server.ip))

        ClusterOperationHelper.start_cluster(servers)
441
442    @staticmethod
443    def change_erlang_gc(servers, value=None):
444        """Change the frequency of erlang_gc process
445           export ERL_FULLSWEEP_AFTER=0 (most aggressive)
446
447        Default: None
448        """
449        log = logger.Logger.get_logger()
450        if value is None:
451            return
452        for server in servers:
453            sh = RemoteMachineShellConnection(server)
454            product = "membase"
455            if sh.is_couchbase_installed():
456                product = "couchbase"
457            command = "sed -i '/exec erl/i export ERL_FULLSWEEP_AFTER=%s' /opt/%s/bin/%s-server" % \
458                      (value, product, product)
459            o, r = sh.execute_command(command)
460            sh.log_command_output(o, r)
461            msg = "modified erlang gc to full_sweep_after %s on %s " % (value, server.ip)
462            log.info(msg)
463
    @staticmethod
    def begin_rebalance_in(master, servers, timeout=5):
        # Delegates straight to RebalanceHelper.begin_rebalance_in.
        RebalanceHelper.begin_rebalance_in(master, servers, timeout)
467
    @staticmethod
    def begin_rebalance_out(master, servers, timeout=5):
        # Delegates straight to RebalanceHelper.begin_rebalance_out.
        RebalanceHelper.begin_rebalance_out(master, servers, timeout)
471
    @staticmethod
    def end_rebalance(master):
        # Delegates straight to RebalanceHelper.end_rebalance.
        RebalanceHelper.end_rebalance(master)
475
476    @staticmethod
477    # Returns the otpNode for Orchestrator
478    def find_orchestrator(master):
479        rest = RestConnection(master)
480        command = "node(global:whereis_name(ns_orchestrator))"
481        status, content = rest.diag_eval(command)
482        # Get rid of single quotes 'ns_1@10.1.3.74'
483        content = content.replace("'", '')
484        return status, content
485
486    @staticmethod
487    def set_vbuckets(master, vbuckets):
488        rest = RestConnection(master)
489        command = "rpc:eval_everywhere(ns_config, set, [couchbase_num_vbuckets_default, {0}]).".format(vbuckets)
490        status, content = rest.diag_eval(command)
491        return status, content
492