#!/usr/bin/env python
# xref: /6.0.3/couchbase-cli/pump.py (revision 25816fd8)

import os
import base64
import copy
import httplib
import logging
import re
import Queue
import json
import string
import sys
import threading
import time
import urllib
import urlparse
import zlib
import platform
import subprocess
import socket

import couchbaseConstants
import cb_bin_client
from cb_util import tag_user_data
from cluster_manager import ClusterManager
from collections import defaultdict
import cbsnappy as snappy

# TODO: (1) optionally log into backup directory

LOGGING_FORMAT = '%(asctime)s: %(threadName)s %(message)s'

NA = 'N/A'

class ProgressReporter(object):
    """Mixin to report progress"""

    def report_init(self):
        self.beg_time = time.time()
        self.prev_time = self.beg_time
        self.prev = defaultdict(int)

    def report(self, prefix="", emit=None):
        if not emit:
            emit = logging.info

        if getattr(self, "source", None):
            emit(prefix + "source : %s" % (self.source))
        if getattr(self, "sink", None):
            emit(prefix + "sink   : %s" % (self.sink))

        cur_time = time.time()
        delta = cur_time - self.prev_time
        c, p = self.cur, self.prev
        x = sorted([k for k in c.iterkeys() if "_sink_" in k])

        width_k = max([5] + [len(k.replace("tot_sink_", "")) for k in x])
        width_v = max([20] + [len(str(c[k])) for k in x])
        width_d = max([10] + [len(str(c[k] - p[k])) for k in x])
        width_s = max([10] + [len("%0.1f" % ((c[k] - p[k]) / delta)) for k in x])
        emit(prefix + " %s : %s | %s | %s"
             % (string.ljust("", width_k),
                string.rjust("total", width_v),
                string.rjust("last", width_d),
                string.rjust("per sec", width_s)))
        verbose_set = ["tot_sink_batch", "tot_sink_msg"]
        for k in x:
            if k not in verbose_set or self.opts.verbose > 0:
                emit(prefix + " %s : %s | %s | %s"
                 % (string.ljust(k.replace("tot_sink_", ""), width_k),
                    string.rjust(str(c[k]), width_v),
                    string.rjust(str(c[k] - p[k]), width_d),
                    string.rjust("%0.1f" % ((c[k] - p[k]) / delta), width_s)))
        self.prev_time = cur_time
        self.prev = copy.copy(c)

    def bar(self, current, total):
        if not total:
            return '.'
        if sys.platform.lower().startswith('win'):
            cr = "\r"
        else:
            cr = chr(27) + "[A\n"
        pct = float(current) / total
        max_hash = 20
        num_hash = int(round(pct * max_hash))
        return ("  [%s%s] %0.1f%% (%s/estimated %s msgs)%s" %
                ('#' * num_hash, ' ' * (max_hash - num_hash),
                 100.0 * pct, current, total, cr))
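
    # For example (hypothetical counts), bar(50, 200) renders a 20-slot bar,
    # roughly:
    #
    #   [#####               ] 25.0% (50/estimated 200 msgs)
    #
    # followed by a platform-specific cursor sequence so the next call can
    # redraw the bar in place.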

class PumpingStation(ProgressReporter):
    """Queues and watchdogs multiple pumps across concurrent workers."""

    def __init__(self, opts, source_class, source_spec, sink_class, sink_spec):
        self.opts = opts
        self.source_class = source_class
        self.source_spec = source_spec
        self.sink_class = sink_class
        self.sink_spec = sink_spec
        self.queue = None
        tmstamp = time.strftime("%Y-%m-%dT%H%M%SZ", time.gmtime())
        self.ctl = { 'stop': False,
                     'rv': 0,
                     'new_session': True,
                     'new_timestamp': tmstamp}
        self.cur = defaultdict(int)

    def run(self):
        # TODO: (6) PumpingStation - monitor source for topology changes.
        # TODO: (4) PumpingStation - retry on err N times, M times / server.
        # TODO: (2) PumpingStation - track checksum in backup, for later restore.

        rv, source_map, sink_map = self.check_endpoints()
        if rv != 0:
            return rv

        if self.opts.dry_run:
            sys.stderr.write("done, but no data written due to dry-run\n")
            return 0

        source_buckets = self.filter_source_buckets(source_map)
        if not source_buckets:
            bucket_source = getattr(self.opts, "bucket_source", None)
            if bucket_source:
                return ("error: there is no bucket: %s at source: %s" %
                        (bucket_source, self.source_spec))
            else:
                return ("error: no transferable buckets at source: %s" %
                        (self.source_spec))

        for source_bucket in sorted(source_buckets,
                                    key=lambda b: b['name']):
            logging.info("bucket: " + source_bucket['name'])

            if not self.opts.extra.get("design_doc_only", 0):
                rv = self.transfer_bucket_msgs(source_bucket, source_map, sink_map)
                if rv != 0:
                    return rv
            else:
                sys.stderr.write("transfer design doc only. bucket msgs will be skipped.\n")

            if not self.opts.extra.get("data_only", 0):
                rv = self.transfer_bucket_design(source_bucket, source_map, sink_map)
                if rv:
                    logging.warn(rv)
                rv = self.transfer_bucket_index(source_bucket, source_map, sink_map)
                if rv:
                    logging.warn(rv)
                rv = self.transfer_bucket_fts_index(source_bucket, source_map, sink_map)
                if rv:
                    logging.warn(rv)
            else:
                sys.stderr.write("transfer data only. bucket design docs and index meta will be skipped.\n")

            # TODO: (5) PumpingStation - validate bucket transfers.

        # TODO: (4) PumpingStation - validate source/sink maps were stable.

        sys.stderr.write("done\n")
        return 0

    def check_endpoints(self):
        logging.debug("source_class: %s", self.source_class)
        rv = self.source_class.check_base(self.opts, self.source_spec)
        if rv != 0:
            return rv, None, None
        rv, source_map = self.source_class.check(self.opts, self.source_spec)
        if rv != 0:
            return rv, None, None

        logging.debug("sink_class: %s", self.sink_class)
        rv = self.sink_class.check_base(self.opts, self.sink_spec)
        if rv != 0:
            return rv, None, None
        rv, sink_map = self.sink_class.check(self.opts, self.sink_spec, source_map)
        if rv != 0:
            return rv, None, None

        return rv, source_map, sink_map
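
    # Note the rv convention used throughout this module: 0 means success,
    # while a non-zero rv is an error string suitable for printing, so
    # callers test "rv != 0" rather than truthiness. A caller sketch
    # (hypothetical "station" instance):
    #
    #   rv, source_map, sink_map = station.check_endpoints()
    #   if rv != 0:
    #       sys.exit(rv)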

    def filter_source_buckets(self, source_map):
        """Filter the source_buckets if a bucket_source was specified."""
        source_buckets = source_map['buckets']
        logging.debug("source_buckets: " +
                      ",".join([n['name'] for n in source_buckets]))

        bucket_source = getattr(self.opts, "bucket_source", None)
        if bucket_source:
            logging.debug("bucket_source: " + bucket_source)
            source_buckets = [b for b in source_buckets
                              if b['name'] == bucket_source]
            logging.debug("source_buckets filtered: " +
                          ",".join([n['name'] for n in source_buckets]))
        return source_buckets

    def filter_source_nodes(self, source_bucket, source_map):
        """Filter the source_bucket's nodes if single_node was specified."""
        if getattr(self.opts, "single_node", None):
            if not source_map.get('spec_parts'):
                return ("error: no single_node from source: %s" +
                        "; the source may not support the --single-node flag") % \
                        (self.source_spec)
            source_nodes = filter_bucket_nodes(source_bucket,
                                               source_map.get('spec_parts'))
        else:
            source_nodes = source_bucket['nodes']

        logging.debug(" source_nodes: " + ",".join([n.get('hostname', NA)
                                                    for n in source_nodes]))
        return source_nodes

    def transfer_bucket_msgs(self, source_bucket, source_map, sink_map):
        source_nodes = self.filter_source_nodes(source_bucket, source_map)

        # Transfer bucket msgs with a Pump per source server.
        self.start_workers(len(source_nodes))
        self.report_init()

        self.ctl['run_msg'] = 0
        self.ctl['tot_msg'] = 0

        for source_node in sorted(source_nodes,
                                  key=lambda n: n.get('hostname', NA)):
            logging.debug(" enqueueing node: " +
                          source_node.get('hostname', NA))
            self.queue.put((source_bucket, source_node, source_map, sink_map))

            rv, tot = self.source_class.total_msgs(self.opts,
                                                   source_bucket,
                                                   source_node,
                                                   source_map)
            if rv != 0:
                return rv
            if tot:
                self.ctl['tot_msg'] += tot

        # Don't use queue.join() as it eats Ctrl-C's.
        s = 0.05
        while self.queue.unfinished_tasks:
            time.sleep(s)
            s = min(1.0, s + 0.01)

        rv = self.ctl['rv']
        if rv != 0:
            return rv

        time.sleep(0.01) # Allows threads to update counters.

        sys.stderr.write(self.bar(self.ctl['run_msg'],
                                  self.ctl['tot_msg']) + "\n")
        sys.stderr.write("bucket: " + source_bucket['name'] +
                         ", msgs transferred...\n")
        def emit(msg):
            sys.stderr.write(msg + "\n")
        self.report(emit=emit)

        return 0

    def transfer_bucket_design(self, source_bucket, source_map, sink_map):
        """Transfer bucket design (e.g., design docs, views)."""
        rv, source_design = \
            self.source_class.provide_design(self.opts, self.source_spec,
                                             source_bucket, source_map)
        if rv == 0:
            if source_design:
                sources = source_design if isinstance(source_design, list) else [source_design]
                for source_design in sources:
                    rv = self.sink_class.consume_design(self.opts,
                                                self.sink_spec, sink_map,
                                                source_bucket, source_map,
                                                source_design)
        return rv

    def transfer_bucket_index(self, source_bucket, source_map, sink_map):
        """Transfer bucket index meta."""
        rv, source_design = \
            self.source_class.provide_index(self.opts, self.source_spec,
                                             source_bucket, source_map)
        if rv == 0:
            if source_design:
                rv = self.sink_class.consume_index(self.opts,
                                                self.sink_spec, sink_map,
                                                source_bucket, source_map,
                                                source_design)
        return rv

    def transfer_bucket_fts_index(self, source_bucket, source_map, sink_map):
        """Transfer bucket FTS index meta."""
        rv, source_design = \
            self.source_class.provide_fts_index(self.opts, self.source_spec,
                                                source_bucket, source_map)
        if rv == 0:
            if source_design:
                rv = self.sink_class.consume_fts_index(self.opts,
                                                   self.sink_spec, sink_map,
                                                   source_bucket, source_map,
                                                   source_design)
        return rv

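    # Note: run_worker below is a @staticmethod used as a thread target; the
    # PumpingStation instance is passed explicitly via args in start_workers(),
    # so "self" is an ordinary parameter rather than an implicitly bound one.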
    @staticmethod
    def run_worker(self, thread_index):
        while not self.ctl['stop']:
            source_bucket, source_node, source_map, sink_map = \
                self.queue.get()
            hostname = source_node.get('hostname', NA)
            logging.debug(" node: %s" % (hostname))

            curx = defaultdict(int)
            self.source_class.check_spec(source_bucket,
                                         source_node,
                                         self.opts,
                                         self.source_spec,
                                         curx)
            self.sink_class.check_spec(source_bucket,
                                       source_node,
                                       self.opts,
                                       self.sink_spec,
                                       curx)

            source = self.source_class(self.opts, self.source_spec, source_bucket,
                                       source_node, source_map, sink_map, self.ctl,
                                       curx)
            sink = self.sink_class(self.opts, self.sink_spec, source_bucket,
                                   source_node, source_map, sink_map, self.ctl,
                                   curx)

            src_conf_res = source.get_conflict_resolution_type()
            snk_conf_res = sink.get_conflict_resolution_type()
            _, snk_bucket = find_sink_bucket_name(self.opts, source_bucket["name"])

            forced = False
            if int(self.opts.extra.get("try_xwm", 1)) == 0:
                forced = True

            if int(self.opts.extra.get("conflict_resolve", 1)) == 0:
                forced = True

            if not forced and snk_conf_res != "any" and src_conf_res != "any" and src_conf_res != snk_conf_res:
                logging.error("Cannot transfer data, source bucket `%s` uses " +
                             "%s conflict resolution but sink bucket `%s` uses " +
                             "%s conflict resolution", source_bucket["name"],
                             src_conf_res, snk_bucket, snk_conf_res)
            else:
                rv = Pump(self.opts, source, sink, source_map, sink_map, self.ctl,
                          curx).run()

                for k, v in curx.items():
                    if isinstance(v, int):
                        self.cur[k] = self.cur.get(k, 0) + v

                logging.debug(" node: %s, done; rv: %s" % (hostname, rv))
                if self.ctl['rv'] == 0 and rv != 0:
                    self.ctl['rv'] = rv

            self.queue.task_done()

    def start_workers(self, queue_size):
        if self.queue:
            return

        self.queue = Queue.Queue(queue_size)

        threads = [threading.Thread(target=PumpingStation.run_worker,
                                    name="w" + str(i), args=(self, i))
                   for i in range(self.opts.threads)]
        for thread in threads:
            thread.daemon = True
            thread.start()

    @staticmethod
    def find_handler(opts, x, classes):
        for s in classes:
            if s.can_handle(opts, x):
                return s
        return None
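
    # A minimal usage sketch (hypothetical spec strings): pick the first
    # registered class whose can_handle() accepts a given spec.
    #
    #   source_class = PumpingStation.find_handler(opts, "stdin:", [StdInSource])
    #   sink_class = PumpingStation.find_handler(opts, "stdout:", [StdOutSink])
    #   if not source_class or not sink_class:
    #       sys.exit("error: unknown source or sink spec")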


class Pump(ProgressReporter):
    """Moves batches of data from one Source to one Sink."""

    def __init__(self, opts, source, sink, source_map, sink_map, ctl, cur):
        self.opts = opts
        self.source = source
        self.sink = sink
        self.source_map = source_map
        self.sink_map = sink_map
        self.ctl = ctl
        self.cur = cur # Should be a defaultdict(int); 0 as default value.

    def run(self):
        future = None

        # TODO: (2) Pump - timeouts when providing/consuming/waiting.

        report = int(self.opts.extra.get("report", 5))
        report_full = int(self.opts.extra.get("report_full", 2000))

        self.report_init()

        n = 0

        while not self.ctl['stop']:
            rv_batch, batch = self.source.provide_batch()
            if rv_batch != 0:
                return self.done(rv_batch)

            if future:
                rv = future.wait_until_consumed()
                if rv != 0:
                    # TODO: (5) Pump - retry logic on consume error.
                    return self.done(rv)

                self.cur['tot_sink_batch'] += 1
                self.cur['tot_sink_msg'] += future.batch.size()
                self.cur['tot_sink_byte'] += future.batch.bytes

                self.ctl['run_msg'] += future.batch.size()
                self.ctl['tot_msg'] += future.batch.adjust_size

            if not batch:
                return self.done(0)

            self.cur['tot_source_batch'] += 1
            self.cur['tot_source_msg'] += batch.size()
            self.cur['tot_source_byte'] += batch.bytes

            rv_future, future = self.sink.consume_batch_async(batch)
            if rv_future != 0:
                return self.done(rv_future)

            n = n + 1
            if report_full > 0 and n % report_full == 0:
                if self.opts.verbose > 0:
                    sys.stderr.write("\n")
                logging.info("  progress...")
                self.report(prefix="  ")
            elif report > 0 and n % report == 0:
                sys.stderr.write(self.bar(self.ctl['run_msg'],
                                          self.ctl['tot_msg']))

        return self.done(0)

    def done(self, rv):
        self.source.close()
        self.sink.close()

        logging.debug("  pump (%s->%s) done.", self.source, self.sink)
        self.report(prefix="  ")

        if (rv == 0 and
            (self.cur['tot_source_batch'] != self.cur['tot_sink_batch'] or
             self.cur['tot_source_msg'] != self.cur['tot_sink_msg'] or
             self.cur['tot_source_byte'] != self.cur['tot_sink_byte'])):
            return "error: sink missing some source msgs: " + str(self.cur)

        return rv


# --------------------------------------------------

class EndPoint(object):

    def __init__(self, opts, spec, source_bucket, source_node,
                 source_map, sink_map, ctl, cur):
        self.opts = opts
        self.spec = spec
        self.source_bucket = source_bucket
        self.source_node = source_node
        self.source_map = source_map
        self.sink_map = sink_map
        self.ctl = ctl
        self.cur = cur

        self.only_key_re = None
        k = getattr(opts, "key", None)
        if k:
            self.only_key_re = re.compile(k)

        self.only_vbucket_id = getattr(opts, "id", None)

    @staticmethod
    def check_base(opts, spec):
        k = getattr(opts, "key", None)
        if k:
            try:
                re.compile(k)
            except:
                return "error: could not parse key regexp: " + k
        return 0

    @staticmethod
    def check_spec(source_bucket, source_node, opts, spec, cur):
        cur['seqno'] = {}
        cur['failoverlog'] = {}
        cur['snapshot'] = {}
        return 0

    def get_conflict_resolution_type(self):
        return "any"

    def __repr__(self):
        return "%s(%s@%s)" % \
            (self.spec,
             self.source_bucket.get('name', ''),
             self.source_node.get('hostname', ''))

    def close(self):
        pass

    def skip(self, key, vbucket_id):
        if (self.only_key_re and not re.search(self.only_key_re, key)):
            logging.warn("skipping msg with key: " + tag_user_data(key))
            return True

        if (self.only_vbucket_id is not None and
            self.only_vbucket_id != vbucket_id):
            logging.warn("skipping msg of vbucket_id: " + str(vbucket_id))
            return True

        return False
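
    # For example, with a hypothetical "--key '^user::'" option only keys
    # matching that regexp survive, and with "--id 101" only msgs from
    # vbucket 101 do; assuming both options are set:
    #
    #   endpoint.skip("user::1001", 101)   # False - msg is transferred
    #   endpoint.skip("stats::1001", 101)  # True  - key regexp mismatch
    #   endpoint.skip("user::1001", 102)   # True  - wrong vbucket_id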

    def get_timestamp(self):
        # UTC wall-clock timestamp with seconds resolution.
        return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    def add_counter(self, key, val=1):
        self.cur[key] = self.cur.get(key, 0.0) + val

    def add_start_event(self, conn):
        return 0

    def add_stop_event(self, conn):
        return 0

535
536class Source(EndPoint):
537    """Base class for all data sources."""
538
539    @staticmethod
540    def can_handle(opts, spec):
541        assert False, "unimplemented"
542
543    @staticmethod
544    def check_base(opts, spec):
545        rv = EndPoint.check_base(opts, spec)
546        if rv != 0:
547            return rv
548        if getattr(opts, "source_vbucket_state", "active") != "active":
549            return ("error: only --source-vbucket-state=active" +
550                    " is supported by this source: %s") % (spec)
551        return 0
552
553    @staticmethod
554    def check(opts, spec):
555        """Subclasses can check preconditions before any pumping starts."""
556        assert False, "unimplemented"
557
558    @staticmethod
559    def provide_design(opts, source_spec, source_bucket, source_map):
560        assert False, "unimplemented"
561
562    @staticmethod
563    def provide_index(opts, source_spec, source_bucket, source_map):
564        return 0, None
565
566    @staticmethod
567    def provide_fts_index(opts, source_spec, source_bucket, source_map):
568        return 0, None
569
570    def provide_batch(self):
571        assert False, "unimplemented"
572
573    @staticmethod
574    def total_msgs(opts, source_bucket, source_node, source_map):
575        return 0, None # Subclasses can return estimate # msgs.
576
577
578class Sink(EndPoint):
579    """Base class for all data sinks."""
580
581    # TODO: (2) Sink handles filtered restore by data.
582
583    def __init__(self, opts, spec, source_bucket, source_node,
584                 source_map, sink_map, ctl, cur):
585        super(Sink, self).__init__(opts, spec, source_bucket, source_node,
586                                   source_map, sink_map, ctl, cur)
587        self.op = None
588
589    @staticmethod
590    def can_handle(opts, spec):
591        assert False, "unimplemented"
592
593    @staticmethod
594    def check_base(opts, spec):
595        rv = EndPoint.check_base(opts, spec)
596        if rv != 0:
597            return rv
598        if getattr(opts, "destination_vbucket_state", "active") != "active":
599            return ("error: only --destination-vbucket-state=active" +
600                    " is supported by this destination: %s") % (spec)
601        if getattr(opts, "destination_operation", None) != None:
602            return ("error: --destination-operation" +
603                    " is not supported by this destination: %s") % (spec)
604        return 0
605
606    @staticmethod
607    def check(opts, spec, source_map):
608        """Subclasses can check preconditions before any pumping starts."""
609        assert False, "unimplemented"
610
611    @staticmethod
612    def consume_design(opts, sink_spec, sink_map,
613                       source_bucket, source_map, source_design):
614        assert False, "unimplemented"
615
616    @staticmethod
617    def consume_index(opts, sink_spec, sink_map,
618                       source_bucket, source_map, source_design):
619        return 0
620
621    @staticmethod
622    def consume_fts_index(opts, sink_spec, sink_map,
623                          source_bucket, source_map, source_design):
624        return 0
625
626    def consume_batch_async(self, batch):
627        """Subclasses should return a SinkBatchFuture."""
628        assert False, "unimplemented"
629
630    @staticmethod
631    def check_source(opts, source_class, source_spec, sink_class, sink_spec):
632        if source_spec == sink_spec:
633            return "error: source and sink must be different;" \
634                " source: " + source_spec + \
635                " sink: " + sink_spec
636        return None
637
638    def operation(self):
639        if not self.op:
640            self.op = getattr(self.opts, "destination_operation", None)
641            if not self.op:
642                self.op = "set"
643                if getattr(self.opts, "add", False):
644                    self.op = "add"
645        return self.op
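
    # In short: the operation defaults to "set"; a truthy opts.add downgrades
    # it to "add" (create-only); and an explicit destination_operation option,
    # when a sink supports one, takes precedence over both.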

    def init_worker(self, target):
        self.worker_go = threading.Event()
        self.worker_work = None # May be None or (batch, future) tuple.
        self.worker = threading.Thread(target=target, args=(self,),
                                       name="s" + threading.currentThread().getName()[1:])
        self.worker.daemon = True
        self.worker.start()

    def push_next_batch(self, batch, future):
        """Push batch/future to worker."""
        if not self.worker.isAlive():
            return "error: cannot use a dead worker", None

        self.worker_work = (batch, future)
        self.worker_go.set()
        return 0, future

    def pull_next_batch(self):
        """Worker calls this method to get the next batch/future."""
        self.worker_go.wait()
        batch, future = self.worker_work
        self.worker_work = None
        self.worker_go.clear()
        return batch, future

    def future_done(self, future, rv):
        """Worker calls this method to finish a batch/future."""
        if rv != 0:
            logging.error("error: async operation: %s on sink: %s" %
                          (rv, self))
        if future:
            future.done_rv = rv
            future.done.set()


# --------------------------------------------------

class Batch(object):
    """Holds a batch of data being transferred from source to sink."""

    def __init__(self, source):
        self.source = source
        self.msgs = []
        self.bytes = 0
        self.adjust_size = 0

    def append(self, msg, num_bytes):
        self.msgs.append(msg)
        self.bytes = self.bytes + num_bytes

    def size(self):
        return len(self.msgs)

    def msg(self, i):
        return self.msgs[i]

    def group_by_vbucket_id(self, vbuckets_num, rehash=0):
        """Returns dict of vbucket_id->[msgs] grouped by msg's vbucket_id."""
        g = defaultdict(list)
        for msg in self.msgs:
            cmd, vbucket_id, key = msg[:3]
            if vbucket_id == 0x0000ffff or rehash == 1:
                # Special case when the source did not supply a vbucket_id
                # (such as stdin source), so we calculate it.
                vbucket_id = ((zlib.crc32(key) >> 16) & 0x7FFF) % vbuckets_num
                msg = (cmd, vbucket_id) + msg[2:]
            g[vbucket_id].append(msg)
        return g
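
    # A sketch of the rehash path (hypothetical key and vbucket count): when
    # a msg carries the 0x0000ffff placeholder, its id is derived from the
    # key with the client-side CRC32 mapping used above:
    #
    #   vbucket_id = ((zlib.crc32("beer-1") >> 16) & 0x7FFF) % 1024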


class SinkBatchFuture(object):
    """Future completion of a sink consuming a batch."""

    def __init__(self, sink, batch):
        self.sink = sink
        self.batch = batch
        self.done = threading.Event()
        self.done_rv = None

    def wait_until_consumed(self):
        self.done.wait()
        return self.done_rv
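
    # Typical producer side (sketch mirroring Pump.run): start an async
    # consume, overlap it with fetching the next source batch, then block:
    #
    #   rv, future = sink.consume_batch_async(batch)
    #   ...provide the next batch from the source...
    #   rv = future.wait_until_consumed()  # done_rv as set via future_done()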


# --------------------------------------------------

class StdInSource(Source):
    """Reads batches from stdin in memcached ascii protocol."""

    def __init__(self, opts, spec, source_bucket, source_node,
                 source_map, sink_map, ctl, cur):
        super(StdInSource, self).__init__(opts, spec, source_bucket, source_node,
                                          source_map, sink_map, ctl, cur)
        self.f = sys.stdin

    @staticmethod
    def can_handle(opts, spec):
        return spec.startswith("stdin:") or spec == "-"

    @staticmethod
    def check(opts, spec):
        return 0, {'spec': spec,
                   'buckets': [{'name': 'stdin:',
                                'nodes': [{'hostname': 'N/A'}]}]}

    @staticmethod
    def provide_design(opts, source_spec, source_bucket, source_map):
        return 0, None

    def provide_batch(self):
        batch = Batch(self)

        batch_max_size = self.opts.extra['batch_max_size']
        batch_max_bytes = self.opts.extra['batch_max_bytes']

        vbucket_id = 0x0000ffff

        while (self.f and
               batch.size() < batch_max_size and
               batch.bytes < batch_max_bytes):
            line = self.f.readline()
            if not line:
                self.f = None
                return 0, batch

            parts = line.split(' ')
            if not parts:
                return "error: read empty line", None
            elif parts[0] == 'set' or parts[0] == 'add':
                if len(parts) != 5:
                    return "error: length of set/add line: " + line, None
                cmd = couchbaseConstants.CMD_TAP_MUTATION
                key = parts[1]
                flg = int(parts[2])
                exp = int(parts[3])
                num = int(parts[4])
                if num > 0:
                    val = self.f.read(num)
                    if len(val) != num:
                        return "error: value read failed at: " + line, None
                else:
                    val = ''
                end = self.f.read(2) # Read '\r\n'.
                if len(end) != 2:
                    return "error: value end read failed at: " + line, None

                if not self.skip(key, vbucket_id):
                    msg = (cmd, vbucket_id, key, flg, exp, 0, '', val, 0, 0, 0)
                    batch.append(msg, len(val))
            elif parts[0] == 'delete':
                if len(parts) != 2:
                    return "error: length of delete line: " + line, None
                cmd = couchbaseConstants.CMD_TAP_DELETE
                key = parts[1].rstrip('\r\n') # split(' ') keeps the trailing newline.
                if not self.skip(key, vbucket_id):
                    msg = (cmd, vbucket_id, key, 0, 0, 0, '', '', 0, 0, 0)
                    batch.append(msg, 0)
            else:
                return "error: expected set/add/delete but got: " + line, None

        if batch.size() <= 0:
            return 0, None

        return 0, batch
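
    # The expected stdin grammar is a subset of the memcached ascii protocol
    # (sketch with hypothetical keys/values):
    #
    #   set beer-1 0 0 2\r\n
    #   {}\r\n
    #   delete beer-2\r\n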


class StdOutSink(Sink):
    """Emits batches to stdout in memcached ascii protocol."""

    @staticmethod
    def can_handle(opts, spec):
        if spec.startswith("stdout:") or spec == "-":
            opts.threads = 1 # Force 1 thread to not overlap stdout.
            return True
        return False

    @staticmethod
    def check(opts, spec, source_map):
        return 0, None

    @staticmethod
    def check_base(opts, spec):
        if getattr(opts, "destination_vbucket_state", "active") != "active":
            return ("error: only --destination-vbucket-state=active" +
                    " is supported by this destination: %s") % (spec)

        op = getattr(opts, "destination_operation", None)
        if op not in [None, 'set', 'add', 'get']:
            return ("error: --destination-operation unsupported value: %s" +
                    "; use set, add, get") % (op)

        # Skip immediate superclass Sink.check_base(),
        # since StdOutSink can handle different destination operations.
        return EndPoint.check_base(opts, spec)

    @staticmethod
    def consume_design(opts, sink_spec, sink_map,
                       source_bucket, source_map, source_design):
        if source_design:
            logging.warn("warning: cannot save bucket design"
                         " on a stdout destination")
        return 0

    def consume_batch_async(self, batch):
        op = self.operation()
        op_mutate = op in ['set', 'add']

        stdout = sys.stdout
        msg_visitor = None

        opts_etc = getattr(self.opts, "etc", None)
        if opts_etc:
            stdout = opts_etc.get("stdout", sys.stdout)
            msg_visitor = opts_etc.get("msg_visitor", None)

        mcd_compatible = self.opts.extra.get("mcd_compatible", 1)
        msg_tuple_format = 0
        for msg in batch.msgs:
            if msg_visitor:
                msg = msg_visitor(msg)
            if not msg_tuple_format:
                msg_tuple_format = len(msg)
            cmd, vbucket_id, key, flg, exp, cas, meta, val = msg[:8]
            seqno = dtype = nmeta = conf_res = 0
            if msg_tuple_format > 8:
                seqno, dtype, nmeta = msg[8:11]
                # 11-tuple msgs (e.g. from StdInSource) carry no conf_res field.
                if msg_tuple_format > 11:
                    conf_res = msg[11]
            if self.skip(key, vbucket_id):
                continue
            if dtype > 2:
                try:
                    val = snappy.uncompress(val)
                except Exception, err:
                    pass
            try:
                if cmd == couchbaseConstants.CMD_TAP_MUTATION or \
                   cmd == couchbaseConstants.CMD_DCP_MUTATION:
                    if op_mutate:
                        # <op> <key> <flags> <exptime> <bytes> [noreply]\r\n
                        if mcd_compatible:
                            stdout.write("%s %s %s %s %s\r\n" %
                                         (op, key, flg, exp, len(val)))
                        else:
                            stdout.write("%s %s %s %s %s %s %s %s\r\n" %
                                         (op, key, flg, exp, len(val), seqno, dtype, conf_res))
                        stdout.write(val)
                        stdout.write("\r\n")
                    elif op == 'get':
                        stdout.write("get %s\r\n" % (key))
                elif cmd == couchbaseConstants.CMD_TAP_DELETE or \
                     cmd == couchbaseConstants.CMD_DCP_DELETE:
                    if op_mutate:
                        stdout.write("delete %s\r\n" % (key))
                elif cmd == couchbaseConstants.CMD_GET:
                    stdout.write("get %s\r\n" % (key))
                else:
                    return "error: StdOutSink - unknown cmd: " + str(cmd), None
            except IOError:
                return "error: could not write to stdout", None

        stdout.flush()
        future = SinkBatchFuture(self, batch)
        self.future_done(future, 0)
        return 0, future


# --------------------------------------------------

CMD_STR = {
    couchbaseConstants.CMD_TAP_CONNECT: "TAP_CONNECT",
    couchbaseConstants.CMD_TAP_MUTATION: "TAP_MUTATION",
    couchbaseConstants.CMD_TAP_DELETE: "TAP_DELETE",
    couchbaseConstants.CMD_TAP_FLUSH: "TAP_FLUSH",
    couchbaseConstants.CMD_TAP_OPAQUE: "TAP_OPAQUE",
    couchbaseConstants.CMD_TAP_VBUCKET_SET: "TAP_VBUCKET_SET",
    couchbaseConstants.CMD_TAP_CHECKPOINT_START: "TAP_CHECKPOINT_START",
    couchbaseConstants.CMD_TAP_CHECKPOINT_END: "TAP_CHECKPOINT_END",
    couchbaseConstants.CMD_NOOP: "NOOP"
}

def get_username(username):
    return username or os.environ.get('CB_REST_USERNAME', '')

def get_password(password):
    return password or os.environ.get('CB_REST_PASSWORD', '')
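
# For example, credentials may come from the environment instead of the
# command line (hypothetical invocation of a tool built on this module):
#
#   CB_REST_USERNAME=Administrator CB_REST_PASSWORD=password \
#       cbtransfer http://HOST:8091 stdout: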

def parse_spec(opts, spec, port):
    """Parse host, port, username, password, path from opts and spec."""

    # Example spec: http://Administrator:password@HOST:8091
    p = urlparse.urlparse(spec)

    # Example netloc: Administrator:password@HOST:8091
    # ParseResult tuple(scheme, netloc, path, params, query, fragment)
    netloc = p[1]

    if not netloc: # When urlparse() can't parse non-http URI's.
        netloc = spec.split('://')[-1].split('/')[0]

    pair = netloc.split('@') # [ "user:pswd", "host:port" ].
    host = p.hostname
    port = p.port
    try:
        port = int(port)
    except (ValueError, TypeError):
        logging.warn("\"%s\" is not a valid port; reset it to default port number" % port)
        port = 8091

    username = get_username(opts.username)
    password = get_password(opts.password)
    if len(pair) > 1:
        username = username or (pair[0] + ':').split(':')[0]
        password = password or (pair[0] + ':').split(':')[1]

    return host, port, username, password, p[2]
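
# A parse sketch (hypothetical spec; assumes opts itself carries no
# credentials and the CB_REST_* variables are unset):
#
#   parse_spec(opts, "http://Administrator:password@10.1.2.3:8091", 8091)
#   -> ("10.1.2.3", 8091, "Administrator", "password", "")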

def rest_request(host, port, user, pswd, ssl, path, method='GET', body='', reason='', headers=None):
    if reason:
        reason = "; reason: %s" % (reason)
    logging.debug("rest_request: %s@%s:%s%s%s" % (tag_user_data(user), host, port, path, reason))
    if ssl:
        if port not in [couchbaseConstants.SSL_REST_PORT, couchbaseConstants.SSL_QUERY_PORT]:
            return ("error: invalid port %s used when ssl option is specified") % port, None, None
        conn = httplib.HTTPSConnection(host, port)
    else:
        conn = httplib.HTTPConnection(host, port)
    try:
        header = rest_headers(user, pswd, headers)
        conn.request(method, path, body, header)
        resp = conn.getresponse()
    except Exception, e:
        return ("error: could not access REST API: %s:%s%s" +
                "; please check source URL, server status, username (-u) and password (-p)" +
                "; exception: %s%s") % \
                (host, port, path, e, reason), None, None

    if resp.status in [200, 201, 202, 204, 302]:
        return None, conn, resp.read()

    conn.close()
    if resp.status == 401:
        return ("error: unable to access REST API: %s:%s%s" +
                "; please check source URL, server status, username (-u) and password (-p)%s") % \
                (host, port, path, reason), None, None

    return ("error: unable to access REST API: %s:%s%s" +
            "; please check source URL, server status, username (-u) and password (-p)" +
            "; response: %s%s") % \
            (host, port, path, resp.status, reason), None, None

def rest_headers(user, pswd, headers=None):
    if not headers:
        headers = {'Content-Type': 'application/json'}
    if user:
        auth = 'Basic ' + \
            string.strip(base64.encodestring(user + ':' + (pswd or '')))
        headers['Authorization'] = auth
    return headers
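
# For example (hypothetical credentials), rest_headers("Administrator", "pass")
# returns {'Content-Type': 'application/json',
#          'Authorization': 'Basic ' + base64 of "Administrator:pass"}.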

def rest_request_json(host, port, user, pswd, ssl, path, reason=''):
    err, conn, rest_json = rest_request(host, port, user, pswd, ssl, path,
                                        reason=reason)
    if err:
        return err, None, None
    if conn:
        conn.close()
    try:
        return None, rest_json, json.loads(rest_json)
    except ValueError, e:
        return ("error: could not decode JSON from REST API: %s:%s%s" +
                "; exception: %s" +
                "; please check URL, username (-u) and password (-p)") % \
                (host, port, path, e), None, None

def rest_couchbase(opts, spec):
    spec = spec.replace('couchbase://', 'http://')
    spec_parts = parse_spec(opts, spec, 8091)
    rest = ClusterManager(spec, opts.username, opts.password, opts.ssl, False,
                          None, False)

    result, errors = rest.list_buckets(True)
    if errors:
        return errors[0], None

    buckets = []
    for bucket in result:
        if bucket["bucketType"] in ["membase", "couchbase", "ephemeral"]:
            buckets.append(bucket)

    return 0, {'spec': spec, 'buckets': buckets, 'spec_parts': spec_parts}

def filter_bucket_nodes(bucket, spec_parts):
    host, port = spec_parts[:2]
    if host in ['localhost', '127.0.0.1']:
        host = get_ip()
    # Convert from raw IPv6
    if ':' in host:
        host_port = '[' + host + ']:' + str(port)
    else:
        host_port = host + ':' + str(port)
    return filter(lambda n: n.get('hostname') == host_port,
                  bucket['nodes'])

def get_ip():
    ip = None
    for fname in ['/opt/couchbase/var/lib/couchbase/ip_start',
                  '/opt/couchbase/var/lib/couchbase/ip',
                  '../var/lib/couchbase/ip_start',
                  '../var/lib/couchbase/ip']:
        try:
            f = open(fname, 'r')
            ip = string.strip(f.read())
            f.close()
            if ip and len(ip):
                # The file may hold "node@host"; keep only the host part.
                if '@' in ip:
                    ip = ip.split('@')[1]
                break
        except:
            pass
    if not ip or not len(ip):
        ip = '127.0.0.1'
    return ip

def find_source_bucket_name(opts, source_map):
    """If the caller didn't specify a bucket_source and
       there's only one bucket in the source_map, use that."""
    source_bucket = getattr(opts, "bucket_source", None)
    if (not source_bucket and
        source_map and
        source_map['buckets'] and
        len(source_map['buckets']) == 1):
        source_bucket = source_map['buckets'][0]['name']
    if not source_bucket:
        return "error: please specify a bucket_source", None
    logging.debug("source_bucket: " + source_bucket)
    return 0, source_bucket

def find_sink_bucket_name(opts, source_bucket):
    """Default bucket_destination to the same as bucket_source."""
    sink_bucket = getattr(opts, "bucket_destination", None) or source_bucket
    if not sink_bucket:
        return "error: please specify a bucket_destination", None
    logging.debug("sink_bucket: " + sink_bucket)
    return 0, sink_bucket

def mkdirs(targetpath):
    upperdirs = os.path.dirname(targetpath)
    if upperdirs and not os.path.exists(upperdirs):
        try:
            os.makedirs(upperdirs)
        except:
            return "Cannot create upper directories for file: %s" % targetpath
    return 0
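
# Note that mkdirs() creates the *parent* directories of targetpath, not
# targetpath itself; e.g. (hypothetical path) mkdirs("/backup/2019/bucket.mb")
# only ensures that /backup/2019 exists.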

def hostport(hoststring, port=11210):
    # Fall back to the whole string plus the default port if neither
    # pattern below matches.
    host = hoststring
    if hoststring.startswith('['):
        matches = re.match(r'^\[([^\]]+)\](:(\d+))?$', hoststring)
    else:
        matches = re.match(r'^([^:]+)(:(\d+))?$', hoststring)
    if matches:
        # The host is the first group
        host = matches.group(1)
        # Optional port is the 3rd group
        if matches.group(3):
            port = int(matches.group(3))
    return host, port
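
# Parse sketches (hypothetical endpoints; note the default memcached port):
#
#   hostport("10.1.2.3")         -> ("10.1.2.3", 11210)
#   hostport("10.1.2.3:12000")   -> ("10.1.2.3", 12000)
#   hostport("[fd63::2]:11210")  -> ("fd63::2", 11210)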

def get_mcd_conn(host, port, username, password, bucket):
    conn = cb_bin_client.MemcachedClient(host, port)
    if not conn:
        return "error: could not connect to memcached: " + \
            host + ":" + str(port), None

    try:
        conn.sasl_auth_plain(username, password)
    except EOFError, e:
        return "error: SASL auth error: %s:%s, %s" % (host, port, e), None
    except cb_bin_client.MemcachedError, e:
        return "error: SASL auth failed: %s:%s, %s" % (host, port, e), None
    except socket.error, e:
        return "error: SASL auth socket error: %s:%s, %s" % (host, port, e), None

    try:
        conn.helo([couchbaseConstants.HELO_XATTR, couchbaseConstants.HELO_XERROR])
    except EOFError, e:
        return "error: HELO error: %s:%s, %s" % (host, port, e), None
    except cb_bin_client.MemcachedError, e:
        return "error: HELO failed: %s:%s, %s" % (host, port, e), None
    except socket.error, e:
        return "error: HELO socket error: %s:%s, %s" % (host, port, e), None

    if bucket:
        try:
            conn.bucket_select(bucket)
        except EOFError, e:
            return "error: Bucket select error: %s:%s %s, %s" % (host, port, bucket, e), None
        except cb_bin_client.MemcachedError, e:
            return "error: Bucket select failed: %s:%s %s, %s" % (host, port, bucket, e), None
        except socket.error, e:
            return "error: Bucket select socket error: %s:%s %s, %s" % (host, port, bucket, e), None

    return 0, conn
1149