xref: /4.6.0/couchbase-cli/pump.py (revision 38e3ecf9)
1#!/usr/bin/env python
2
3import os
4import base64
5import copy
6import httplib
7import logging
8import re
9import Queue
10import json
11import string
12import sys
13import threading
14import time
15import urllib
16import urlparse
17import zlib
18import platform
19import subprocess
20
21import couchbaseConstants
22from collections import defaultdict
23import cbsnappy as snappy
24
25# TODO: (1) optionally log into backup directory
26
27LOGGING_FORMAT = '%(asctime)s: %(threadName)s %(message)s'
28
29NA = 'N/A'
30
31class ProgressReporter(object):
32    """Mixin to report progress"""
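    # Expects the class mixing this in to provide self.opts and self.cur, a
    # defaultdict(int) of counters; report() tabulates every counter whose
    # name contains "_sink_" as a running total, the delta since the last
    # report, and a per-second rate over that interval.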
33
34    def report_init(self):
35        self.beg_time = time.time()
36        self.prev_time = self.beg_time
37        self.prev = defaultdict(int)
38
39    def report(self, prefix="", emit=None):
40        if not emit:
41            emit = logging.info
42
43        if getattr(self, "source", None):
44            emit(prefix + "source : %s" % (self.source))
45        if getattr(self, "sink", None):
46            emit(prefix + "sink   : %s" % (self.sink))
47
48        cur_time = time.time()
49        delta = cur_time - self.prev_time
50        c, p = self.cur, self.prev
51        x = sorted([k for k in c.iterkeys() if "_sink_" in k])
52
53        width_k = max([5] + [len(k.replace("tot_sink_", "")) for k in x])
54        width_v = max([20] + [len(str(c[k])) for k in x])
55        width_d = max([10] + [len(str(c[k] - p[k])) for k in x])
56        width_s = max([10] + [len("%0.1f" % ((c[k] - p[k]) / delta)) for k in x])
57        emit(prefix + " %s : %s | %s | %s"
58             % (string.ljust("", width_k),
59                string.rjust("total", width_v),
60                string.rjust("last", width_d),
61                string.rjust("per sec", width_s)))
62        verbose_set = ["tot_sink_batch", "tot_sink_msg"]
63        for k in x:
64            if k not in verbose_set or self.opts.verbose > 0:
65                emit(prefix + " %s : %s | %s | %s"
66                 % (string.ljust(k.replace("tot_sink_", ""), width_k),
67                    string.rjust(str(c[k]), width_v),
68                    string.rjust(str(c[k] - p[k]), width_d),
69                    string.rjust("%0.1f" % ((c[k] - p[k]) / delta), width_s)))
70        self.prev_time = cur_time
71        self.prev = copy.copy(c)
72
73    def bar(self, current, total):
74        if not total:
75            return '.'
76        if sys.platform.lower().startswith('win'):
77            cr = "\r"
78        else:
79            cr = chr(27) + "[A\n"
80        pct = float(current) / total
81        max_hash = 20
82        num_hash = int(round(pct * max_hash))
83        return ("  [%s%s] %0.1f%% (%s/estimated %s msgs)%s" %
84                ('#' * num_hash, ' ' * (max_hash - num_hash),
85                 100.0 * pct, current, total, cr))
86
87class PumpingStation(ProgressReporter):
88    """Queues and watchdogs multiple pumps across concurrent workers."""
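    # run() checks both endpoints, then for each source bucket enqueues one
    # (bucket, node, source_map, sink_map) work item per source node; worker
    # threads (run_worker) pump the data for each node, after which design
    # docs and index metadata are transferred once per bucket.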
89
90    def __init__(self, opts, source_class, source_spec, sink_class, sink_spec):
91        self.opts = opts
92        self.source_class = source_class
93        self.source_spec = source_spec
94        self.sink_class = sink_class
95        self.sink_spec = sink_spec
96        self.queue = None
97        tmstamp = time.strftime("%Y-%m-%dT%H%M%SZ", time.gmtime())
98        self.ctl = { 'stop': False,
99                     'rv': 0,
100                     'new_session': True,
101                     'new_timestamp': tmstamp}
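        # self.ctl is shared by reference with the worker threads and Pump
        # instances: 'stop' asks workers to exit early, 'rv' records the first
        # non-zero result, and transfer_bucket_msgs() later adds the
        # 'run_msg'/'tot_msg' progress counters.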
102        self.cur = defaultdict(int)
103
104    def run(self):
105        # TODO: (6) PumpingStation - monitor source for topology changes.
106        # TODO: (4) PumpingStation - retry on err N times, M times / server.
107        # TODO: (2) PumpingStation - track checksum in backup, for later restore.
108
109        rv, source_map, sink_map = self.check_endpoints()
110        if rv != 0:
111            return rv
112
113        if self.opts.dry_run:
114            sys.stderr.write("done, but no data written due to dry-run\n")
115            return 0
116
117        source_buckets = self.filter_source_buckets(source_map)
118        if not source_buckets:
119            bucket_source = getattr(self.opts, "bucket_source", None)
120            if bucket_source:
121                return ("error: there is no bucket: %s at source: %s" %
122                        (bucket_source, self.source_spec))
123            else:
124                return ("error: no transferable buckets at source: %s" %
125                        (self.source_spec))
126
127        for source_bucket in sorted(source_buckets,
128                                    key=lambda b: b['name']):
129            logging.info("bucket: " + source_bucket['name'])
130
131            if not self.opts.extra.get("design_doc_only", 0):
132                rv = self.transfer_bucket_msgs(source_bucket, source_map, sink_map)
133                if rv != 0:
134                    return rv
135            else:
136                sys.stderr.write("transfer design doc only. bucket msgs will be skipped.\n")
137
138            if not self.opts.extra.get("data_only", 0):
139                self.transfer_bucket_design(source_bucket, source_map, sink_map)
140                self.transfer_bucket_index(source_bucket, source_map, sink_map)
141            else:
142                sys.stderr.write("transfer data only. bucket design docs and index meta will be skipped.\n")
143
144            # TODO: (5) PumpingStation - validate bucket transfers.
145
146        # TODO: (4) PumpingStation - validate source/sink maps were stable.
147
148        sys.stderr.write("done\n")
149        return 0
150
151    def check_endpoints(self):
152        logging.debug("source_class: %s", self.source_class)
153        rv = self.source_class.check_base(self.opts, self.source_spec)
154        if rv != 0:
155            return rv, None, None
156        rv, source_map = self.source_class.check(self.opts, self.source_spec)
157        if rv != 0:
158            return rv, None, None
159
160        logging.debug("sink_class: %s", self.sink_class)
161        rv = self.sink_class.check_base(self.opts, self.sink_spec)
162        if rv != 0:
163            return rv, None, None
164        rv, sink_map = self.sink_class.check(self.opts, self.sink_spec, source_map)
165        if rv != 0:
166            return rv, None, None
167
168        return rv, source_map, sink_map
169
170    def filter_source_buckets(self, source_map):
171        """Filter the source_buckets if a bucket_source was specified."""
172        source_buckets = source_map['buckets']
173        logging.debug("source_buckets: " +
174                      ",".join([n['name'] for n in source_buckets]))
175
176        bucket_source = getattr(self.opts, "bucket_source", None)
177        if bucket_source:
178            logging.debug("bucket_source: " + bucket_source)
179            source_buckets = [b for b in source_buckets
180                              if b['name'] == bucket_source]
181            logging.debug("source_buckets filtered: " +
182                          ",".join([n['name'] for n in source_buckets]))
183        return source_buckets
184
185    def filter_source_nodes(self, source_bucket, source_map):
186        """Filter the source_bucket's nodes if single_node was specified."""
187        if getattr(self.opts, "single_node", None):
188            if not source_map.get('spec_parts'):
189                return ("error: no single_node from source: %s" +
190                        "; the source may not support the --single-node flag") % \
191                        (self.source_spec)
192            source_nodes = filter_bucket_nodes(source_bucket,
193                                               source_map.get('spec_parts'))
194        else:
195            source_nodes = source_bucket['nodes']
196
197        logging.debug(" source_nodes: " + ",".join([n.get('hostname', NA)
198                                                    for n in source_nodes]))
199        return source_nodes
200
201    def transfer_bucket_msgs(self, source_bucket, source_map, sink_map):
202        source_nodes = self.filter_source_nodes(source_bucket, source_map)
203
204        # Transfer bucket msgs with a Pump per source server.
205        self.start_workers(len(source_nodes))
206        self.report_init()
207
208        self.ctl['run_msg'] = 0
209        self.ctl['tot_msg'] = 0
210
211        for source_node in sorted(source_nodes,
212                                  key=lambda n: n.get('hostname', NA)):
213            logging.debug(" enqueueing node: " +
214                          source_node.get('hostname', NA))
215            self.queue.put((source_bucket, source_node, source_map, sink_map))
216
217            rv, tot = self.source_class.total_msgs(self.opts,
218                                                   source_bucket,
219                                                   source_node,
220                                                   source_map)
221            if rv != 0:
222                return rv
223            if tot:
224                self.ctl['tot_msg'] += tot
225
226        # Don't use queue.join() as it eats Ctrl-C's.
227        s = 0.05
228        while self.queue.unfinished_tasks:
229            time.sleep(s)
230            s = min(1.0, s + 0.01)
231
232        rv = self.ctl['rv']
233        if rv != 0:
234            return rv
235
236        time.sleep(0.01) # Allows threads to update counters.
237
238        sys.stderr.write(self.bar(self.ctl['run_msg'],
239                                  self.ctl['tot_msg']) + "\n")
240        sys.stderr.write("bucket: " + source_bucket['name'] +
241                         ", msgs transferred...\n")
242        def emit(msg):
243            sys.stderr.write(msg + "\n")
244        self.report(emit=emit)
245
246        return 0
247
248    def transfer_bucket_design(self, source_bucket, source_map, sink_map):
249        """Transfer bucket design (e.g., design docs, views)."""
250        rv, source_design = \
251            self.source_class.provide_design(self.opts, self.source_spec,
252                                             source_bucket, source_map)
253        if rv == 0:
254            if source_design:
255                sources = source_design if isinstance(source_design, list) else [source_design]
256                for source_design in sources:
257                    rv = self.sink_class.consume_design(self.opts,
258                                                self.sink_spec, sink_map,
259                                                source_bucket, source_map,
260                                                source_design)
261        return rv
262
263    def transfer_bucket_index(self, source_bucket, source_map, sink_map):
264        """Transfer bucket index meta."""
265        rv, source_design = \
266            self.source_class.provide_index(self.opts, self.source_spec,
267                                             source_bucket, source_map)
268        if rv == 0:
269            if source_design:
270                rv = self.sink_class.consume_index(self.opts,
271                                                self.sink_spec, sink_map,
272                                                source_bucket, source_map,
273                                                source_design)
274        return rv
275
276    @staticmethod
277    def run_worker(self, thread_index):
278        while not self.ctl['stop']:
279            source_bucket, source_node, source_map, sink_map = \
280                self.queue.get()
281            hostname = source_node.get('hostname', NA)
282            logging.debug(" node: %s" % (hostname))
283
284            curx = defaultdict(int)
285            self.source_class.check_spec(source_bucket,
286                                         source_node,
287                                         self.opts,
288                                         self.source_spec,
289                                         curx)
290            self.sink_class.check_spec(source_bucket,
291                                       source_node,
292                                       self.opts,
293                                       self.sink_spec,
294                                       curx)
295
296            source = self.source_class(self.opts, self.source_spec, source_bucket,
297                                       source_node, source_map, sink_map, self.ctl,
298                                       curx)
299            sink = self.sink_class(self.opts, self.sink_spec, source_bucket,
300                                   source_node, source_map, sink_map, self.ctl,
301                                   curx)
302
303            src_conf_res = source.get_conflict_resolution_type()
304            snk_conf_res = sink.get_conflict_resolution_type()
305            _, snk_bucket = find_sink_bucket_name(self.opts, source_bucket["name"])
306
307            forced = False
308            if int(self.opts.extra.get("try_xwm", 1)) == 0:
309                forced = True
310
311            if int(self.opts.extra.get("conflict_resolve", 1)) == 0:
312                forced = True
313
314            if not forced and snk_conf_res != "any" and src_conf_res != "any" and src_conf_res != snk_conf_res:
315                logging.error("Cannot transfer data, source bucket `%s` uses " +
316                             "%s conflict resolution but sink bucket `%s` uses " +
317                             "%s conflict resolution", source_bucket["name"],
318                             src_conf_res, snk_bucket, snk_conf_res)
319            else:
320                rv = Pump(self.opts, source, sink, source_map, sink_map, self.ctl,
321                          curx).run()
322
323                for k, v in curx.items():
324                    if isinstance(v, int):
325                        self.cur[k] = self.cur.get(k, 0) + v
326
327                logging.debug(" node: %s, done; rv: %s" % (hostname, rv))
328                if self.ctl['rv'] == 0 and rv != 0:
329                    self.ctl['rv'] = rv
330
331            self.queue.task_done()
332
333    def start_workers(self, queue_size):
334        if self.queue:
335            return
336
337        self.queue = Queue.Queue(queue_size)
338
339        threads = [threading.Thread(target=PumpingStation.run_worker,
340                                    name="w" + str(i), args=(self, i))
341                   for i in range(self.opts.threads)]
342        for thread in threads:
343            thread.daemon = True
344            thread.start()
345
346    @staticmethod
347    def find_handler(opts, x, classes):
348        for s in classes:
349            if s.can_handle(opts, x):
350                return s
351        return None
352
353
354class Pump(ProgressReporter):
355    """Moves batches of data from one Source to one Sink."""
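    # run() overlaps source reads with sink writes: each batch is handed to
    # the sink with consume_batch_async(), and the returned future is only
    # awaited after the next source batch has been fetched.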
356
357    def __init__(self, opts, source, sink, source_map, sink_map, ctl, cur):
358        self.opts = opts
359        self.source = source
360        self.sink = sink
361        self.source_map = source_map
362        self.sink_map = sink_map
363        self.ctl = ctl
364        self.cur = cur # Should be a defaultdict(int); 0 as default value.
365
366    def run(self):
367        future = None
368
369        # TODO: (2) Pump - timeouts when providing/consuming/waiting.
370
371        report = int(self.opts.extra.get("report", 5))
372        report_full = int(self.opts.extra.get("report_full", 2000))
373
374        self.report_init()
375
376        n = 0
377
378        while not self.ctl['stop']:
379            rv_batch, batch = self.source.provide_batch()
380            if rv_batch != 0:
381                return self.done(rv_batch)
382
383            if future:
384                rv = future.wait_until_consumed()
385                if rv != 0:
386                    # TODO: (5) Pump - retry logic on consume error.
387                    return self.done(rv)
388
389                self.cur['tot_sink_batch'] += 1
390                self.cur['tot_sink_msg'] += future.batch.size()
391                self.cur['tot_sink_byte'] += future.batch.bytes
392
393                self.ctl['run_msg'] += future.batch.size()
394                self.ctl['tot_msg'] += future.batch.adjust_size
395
396            if not batch:
397                return self.done(0)
398
399            self.cur['tot_source_batch'] += 1
400            self.cur['tot_source_msg'] += batch.size()
401            self.cur['tot_source_byte'] += batch.bytes
402
403            rv_future, future = self.sink.consume_batch_async(batch)
404            if rv_future != 0:
405                return self.done(rv_future)
406
407            n = n + 1
408            if report_full > 0 and n % report_full == 0:
409                if self.opts.verbose > 0:
410                    sys.stderr.write("\n")
411                logging.info("  progress...")
412                self.report(prefix="  ")
413            elif report > 0 and n % report == 0:
414                sys.stderr.write(self.bar(self.ctl['run_msg'],
415                                          self.ctl['tot_msg']))
416
417        return self.done(0)
418
419    def done(self, rv):
420        self.source.close()
421        self.sink.close()
422
423        logging.debug("  pump (%s->%s) done.", self.source, self.sink)
424        self.report(prefix="  ")
425
426        if (rv == 0 and
427            (self.cur['tot_source_batch'] != self.cur['tot_sink_batch'] or
428             self.cur['tot_source_msg'] != self.cur['tot_sink_msg'] or
429             self.cur['tot_source_byte'] != self.cur['tot_sink_byte'])):
430            return "error: sink missing some source msgs: " + str(self.cur)
431
432        return rv
433
434
435# --------------------------------------------------
436
437class EndPoint(object):
438
439    def __init__(self, opts, spec, source_bucket, source_node,
440                 source_map, sink_map, ctl, cur):
441        self.opts = opts
442        self.spec = spec
443        self.source_bucket = source_bucket
444        self.source_node = source_node
445        self.source_map = source_map
446        self.sink_map = sink_map
447        self.ctl = ctl
448        self.cur = cur
449
450        self.only_key_re = None
451        k = getattr(opts, "key", None)
452        if k:
453            self.only_key_re = re.compile(k)
454
455        self.only_vbucket_id = getattr(opts, "id", None)
456
457    @staticmethod
458    def check_base(opts, spec):
459        k = getattr(opts, "key", None)
460        if k:
461            try:
462                re.compile(k)
463            except:
464                return "error: could not parse key regexp: " + k
465        return 0
466
467    @staticmethod
468    def check_spec(source_bucket, source_node, opts, spec, cur):
469        cur['seqno'] = {}
470        cur['failoverlog'] = {}
471        cur['snapshot'] = {}
472        return 0
473
474    def get_conflict_resolution_type(self):
475        return "any"
476
477    def __repr__(self):
478        return "%s(%s@%s)" % \
479            (self.spec,
480             self.source_bucket.get('name', ''),
481             self.source_node.get('hostname', ''))
482
483    def close(self):
484        pass
485
486    def skip(self, key, vbucket_id):
487        if (self.only_key_re and not re.search(self.only_key_re, key)):
488            logging.warn("skipping msg with key: " + str(key))
489            return True
490
491        if (self.only_vbucket_id is not None and
492            self.only_vbucket_id != vbucket_id):
493            logging.warn("skipping msg of vbucket_id: " + str(vbucket_id))
494            return True
495
496        return False
497
498    def get_timestamp(self):
499        # UTC timestamp, second resolution (no millisecond component).
500        return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
501
502    def add_counter(self, key, val=1):
503        self.cur[key] = self.cur.get(key, 0.0) + val
504
505    def add_start_event(self, conn):
506        return 0
507
508    def add_stop_event(self, conn):
509        return 0
510
511class Source(EndPoint):
512    """Base class for all data sources."""
513
514    @staticmethod
515    def can_handle(opts, spec):
516        assert False, "unimplemented"
517
518    @staticmethod
519    def check_base(opts, spec):
520        rv = EndPoint.check_base(opts, spec)
521        if rv != 0:
522            return rv
523        if getattr(opts, "source_vbucket_state", "active") != "active":
524            return ("error: only --source-vbucket-state=active" +
525                    " is supported by this source: %s") % (spec)
526        return 0
527
528    @staticmethod
529    def check(opts, spec):
530        """Subclasses can check preconditions before any pumping starts."""
531        assert False, "unimplemented"
532
533    @staticmethod
534    def provide_design(opts, source_spec, source_bucket, source_map):
535        assert False, "unimplemented"
536
537    @staticmethod
538    def provide_index(opts, source_spec, source_bucket, source_map):
539        return 0, None
540
541    def provide_batch(self):
542        assert False, "unimplemented"
543
544    @staticmethod
545    def total_msgs(opts, source_bucket, source_node, source_map):
546        return 0, None # Subclasses can return estimate # msgs.
547
548
549class Sink(EndPoint):
550    """Base class for all data sinks."""
551
552    # TODO: (2) Sink handles filtered restore by data.
553
554    def __init__(self, opts, spec, source_bucket, source_node,
555                 source_map, sink_map, ctl, cur):
556        super(Sink, self).__init__(opts, spec, source_bucket, source_node,
557                                   source_map, sink_map, ctl, cur)
558        self.op = None
559
560    @staticmethod
561    def can_handle(opts, spec):
562        assert False, "unimplemented"
563
564    @staticmethod
565    def check_base(opts, spec):
566        rv = EndPoint.check_base(opts, spec)
567        if rv != 0:
568            return rv
569        if getattr(opts, "destination_vbucket_state", "active") != "active":
570            return ("error: only --destination-vbucket-state=active" +
571                    " is supported by this destination: %s") % (spec)
572        if getattr(opts, "destination_operation", None) != None:
573            return ("error: --destination-operation" +
574                    " is not supported by this destination: %s") % (spec)
575        return 0
576
577    @staticmethod
578    def check(opts, spec, source_map):
579        """Subclasses can check preconditions before any pumping starts."""
580        assert False, "unimplemented"
581
582    @staticmethod
583    def consume_design(opts, sink_spec, sink_map,
584                       source_bucket, source_map, source_design):
585        assert False, "unimplemented"
586
587    @staticmethod
588    def consume_index(opts, sink_spec, sink_map,
589                       source_bucket, source_map, source_design):
590        return 0
591
592    def consume_batch_async(self, batch):
593        """Subclasses should return a SinkBatchFuture."""
594        assert False, "unimplemented"
595
596    @staticmethod
597    def check_source(opts, source_class, source_spec, sink_class, sink_spec):
598        if source_spec == sink_spec:
599            return "error: source and sink must be different;" \
600                " source: " + source_spec + \
601                " sink: " + sink_spec
602        return None
603
604    def operation(self):
605        if not self.op:
606            self.op = getattr(self.opts, "destination_operation", None)
607            if not self.op:
608                self.op = "set"
609                if getattr(self.opts, "add", False):
610                    self.op = "add"
611        return self.op
612
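    # The methods below form a simple hand-off protocol for sinks that consume
    # batches on a single background thread: init_worker() starts the daemon
    # thread, push_next_batch() passes it a (batch, future) pair via the
    # worker_go event, the thread fetches work with pull_next_batch(), and it
    # reports completion through future_done().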
613    def init_worker(self, target):
614        self.worker_go = threading.Event()
615        self.worker_work = None # May be None or (batch, future) tuple.
616        self.worker = threading.Thread(target=target, args=(self,),
617                                       name="s" + threading.currentThread().getName()[1:])
618        self.worker.daemon = True
619        self.worker.start()
620
621    def push_next_batch(self, batch, future):
622        """Push batch/future to worker."""
623        if not self.worker.isAlive():
624            return "error: cannot use a dead worker", None
625
626        self.worker_work = (batch, future)
627        self.worker_go.set()
628        return 0, future
629
630    def pull_next_batch(self):
631        """Worker calls this method to get the next batch/future."""
632        self.worker_go.wait()
633        batch, future = self.worker_work
634        self.worker_work = None
635        self.worker_go.clear()
636        return batch, future
637
638    def future_done(self, future, rv):
639        """Worker calls this method to finish a batch/future."""
640        if rv != 0:
641            logging.error("error: async operation: %s on sink: %s" %
642                          (rv, self))
643        if future:
644            future.done_rv = rv
645            future.done.set()
646
647
648# --------------------------------------------------
649
650class Batch(object):
651    """Holds a batch of data being transferred from source to sink."""
652
653    def __init__(self, source):
654        self.source = source
655        self.msgs = []
656        self.bytes = 0
657        self.adjust_size = 0
658
659    def append(self, msg, num_bytes):
660        self.msgs.append(msg)
661        self.bytes = self.bytes + num_bytes
662
663    def size(self):
664        return len(self.msgs)
665
666    def msg(self, i):
667        return self.msgs[i]
668
669    def group_by_vbucket_id(self, vbuckets_num, rehash=0):
670        """Returns dict of vbucket_id->[msgs] grouped by msg's vbucket_id."""
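        # The rehash path assumes vbuckets_num is a power of two, so masking
        # the high 16 bits of crc32(key) with (vbuckets_num - 1) behaves like
        # a modulo over the vbucket space.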
671        g = defaultdict(list)
672        for msg in self.msgs:
673            cmd, vbucket_id, key = msg[:3]
674            if vbucket_id == 0x0000ffff or rehash == 1:
675                # Special case when the source did not supply a vbucket_id
676                # (such as stdin source), so we calculate it.
677                vbucket_id = (zlib.crc32(key) >> 16) & (vbuckets_num - 1)
678                msg = (cmd, vbucket_id) + msg[2:]
679            g[vbucket_id].append(msg)
680        return g
681
682
683class SinkBatchFuture(object):
684    """Future completion of a sink consuming a batch."""
685
686    def __init__(self, sink, batch):
687        self.sink = sink
688        self.batch = batch
689        self.done = threading.Event()
690        self.done_rv = None
691
692    def wait_until_consumed(self):
693        self.done.wait()
694        return self.done_rv
695
696
697# --------------------------------------------------
698
699class StdInSource(Source):
700    """Reads batches from stdin in memcached ascii protocol."""
701
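    # provide_batch() understands these line-oriented commands on stdin:
    #   set <key> <flags> <exptime> <bytes>\r\n<value>\r\n
    #   add <key> <flags> <exptime> <bytes>\r\n<value>\r\n
    #   delete <key>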
702    def __init__(self, opts, spec, source_bucket, source_node,
703                 source_map, sink_map, ctl, cur):
704        super(StdInSource, self).__init__(opts, spec, source_bucket, source_node,
705                                          source_map, sink_map, ctl, cur)
706        self.f = sys.stdin
707
708    @staticmethod
709    def can_handle(opts, spec):
710        return spec.startswith("stdin:") or spec == "-"
711
712    @staticmethod
713    def check(opts, spec):
714        return 0, {'spec': spec,
715                   'buckets': [{'name': 'stdin:',
716                                'nodes': [{'hostname': 'N/A'}]}] }
717
718    @staticmethod
719    def provide_design(opts, source_spec, source_bucket, source_map):
720        return 0, None
721
722    def provide_batch(self):
723        batch = Batch(self)
724
725        batch_max_size = self.opts.extra['batch_max_size']
726        batch_max_bytes = self.opts.extra['batch_max_bytes']
727
728        vbucket_id = 0x0000ffff
729
730        while (self.f and
731               batch.size() < batch_max_size and
732               batch.bytes < batch_max_bytes):
733            line = self.f.readline()
734            if not line:
735                self.f = None
736                return 0, batch
737
738            parts = line.split(' ')
739            if not parts:
740                return "error: read empty line", None
741            elif parts[0] == 'set' or parts[0] == 'add':
742                if len(parts) != 5:
743                    return "error: length of set/add line: " + line, None
744                cmd = couchbaseConstants.CMD_TAP_MUTATION
745                key = parts[1]
746                flg = int(parts[2])
747                exp = int(parts[3])
748                num = int(parts[4])
749                if num > 0:
750                    val = self.f.read(num)
751                    if len(val) != num:
752                        return "error: value read failed at: " + line, None
753                else:
754                    val = ''
755                end = self.f.read(2) # Read '\r\n'.
756                if len(end) != 2:
757                    return "error: value end read failed at: " + line, None
758
759                if not self.skip(key, vbucket_id):
760                    msg = (cmd, vbucket_id, key, flg, exp, 0, '', val, 0, 0, 0)
761                    batch.append(msg, len(val))
762            elif parts[0] == 'delete':
763                if len(parts) != 2:
764                    return "error: length of delete line: " + line, None
765                cmd = couchbaseConstants.CMD_TAP_DELETE
766                key = parts[1]
767                if not self.skip(key, vbucket_id):
768                    msg = (cmd, vbucket_id, key, 0, 0, 0, '', '', 0, 0, 0)
769                    batch.append(msg, 0)
770            else:
771                return "error: expected set/add/delete but got: " + line, None
772
773        if batch.size() <= 0:
774            return 0, None
775
776        return 0, batch
777
778
779class StdOutSink(Sink):
780    """Emits batches to stdout in memcached ascii protocol."""
781
782    @staticmethod
783    def can_handle(opts, spec):
784        if spec.startswith("stdout:") or spec == "-":
785            opts.threads = 1 # Force 1 thread to not overlap stdout.
786            return True
787        return False
788
789    @staticmethod
790    def check(opts, spec, source_map):
791        return 0, None
792
793    @staticmethod
794    def check_base(opts, spec):
795        if getattr(opts, "destination_vbucket_state", "active") != "active":
796            return ("error: only --destination-vbucket-state=active" +
797                    " is supported by this destination: %s") % (spec)
798
799        op = getattr(opts, "destination_operation", None)
800        if not op in [None, 'set', 'add', 'get']:
801            return ("error: --destination-operation unsupported value: %s" +
802                    "; use set, add, get") % (op)
803
804        # Skip immediate superclass Sink.check_base(),
805        # since StdOutSink can handle different destination operations.
806        return EndPoint.check_base(opts, spec)
807
808    @staticmethod
809    def consume_design(opts, sink_spec, sink_map,
810                       source_bucket, source_map, source_design):
811        if source_design:
812            logging.warn("warning: cannot save bucket design"
813                         " on a stdout destination")
814        return 0
815
816    def consume_batch_async(self, batch):
817        op = self.operation()
818        op_mutate = op in ['set', 'add']
819
820        stdout = sys.stdout
821        msg_visitor = None
822
823        opts_etc = getattr(self.opts, "etc", None)
824        if opts_etc:
825            stdout = opts_etc.get("stdout", sys.stdout)
826            msg_visitor = opts_etc.get("msg_visitor", None)
827
828        mcd_compatible = self.opts.extra.get("mcd_compatible", 1)
829        msg_tuple_format = 0
830        for msg in batch.msgs:
831            if msg_visitor:
832                msg = msg_visitor(msg)
833            if not msg_tuple_format:
834                msg_tuple_format = len(msg)
835            cmd, vbucket_id, key, flg, exp, cas, meta, val = msg[:8]
836            seqno = dtype = nmeta = conf_res = 0
837            if msg_tuple_format > 8:
838                seqno, dtype, nmeta, conf_res = msg[8:]
839            if self.skip(key, vbucket_id):
840                continue
841            if dtype > 2:
842                try:
843                    val = snappy.uncompress(val)
844                except Exception, err:
845                    pass
846            try:
847                if cmd == couchbaseConstants.CMD_TAP_MUTATION or \
848                   cmd == couchbaseConstants.CMD_DCP_MUTATION:
849                    if op_mutate:
850                        # <op> <key> <flags> <exptime> <bytes> [noreply]\r\n
851                        if mcd_compatible:
852                            stdout.write("%s %s %s %s %s\r\n" %
853                                         (op, key, flg, exp, len(val)))
854                        else:
855                            stdout.write("%s %s %s %s %s %s %s %s\r\n" %
856                                         (op, key, flg, exp, len(val), seqno, dtype, conf_res))
857                        stdout.write(val)
858                        stdout.write("\r\n")
859                    elif op == 'get':
860                        stdout.write("get %s\r\n" % (key))
861                elif cmd == couchbaseConstants.CMD_TAP_DELETE or \
862                     cmd == couchbaseConstants.CMD_DCP_DELETE:
863                    if op_mutate:
864                        stdout.write("delete %s\r\n" % (key))
865                elif cmd == couchbaseConstants.CMD_GET:
866                    stdout.write("get %s\r\n" % (key))
867                else:
868                    return "error: StdOutSink - unknown cmd: " + str(cmd), None
869            except IOError:
870                return "error: could not write to stdout", None
871
872        stdout.flush()
873        future = SinkBatchFuture(self, batch)
874        self.future_done(future, 0)
875        return 0, future
876
877
878# --------------------------------------------------
879
880CMD_STR = {
881    couchbaseConstants.CMD_TAP_CONNECT: "TAP_CONNECT",
882    couchbaseConstants.CMD_TAP_MUTATION: "TAP_MUTATION",
883    couchbaseConstants.CMD_TAP_DELETE: "TAP_DELETE",
884    couchbaseConstants.CMD_TAP_FLUSH: "TAP_FLUSH",
885    couchbaseConstants.CMD_TAP_OPAQUE: "TAP_OPAQUE",
886    couchbaseConstants.CMD_TAP_VBUCKET_SET: "TAP_VBUCKET_SET",
887    couchbaseConstants.CMD_TAP_CHECKPOINT_START: "TAP_CHECKPOINT_START",
888    couchbaseConstants.CMD_TAP_CHECKPOINT_END: "TAP_CHECKPOINT_END",
889    couchbaseConstants.CMD_NOOP: "NOOP"
890}
891
892def get_username(username):
893    return username or os.environ.get('CB_REST_USERNAME', '')
894
895def get_password(password):
896    return password or os.environ.get('CB_REST_PASSWORD', '')
897
898def parse_spec(opts, spec, port):
899    """Parse host, port, username, password, path from opts and spec."""
900
901    # Example spec: http://Administrator:password@HOST:8091
902    p = urlparse.urlparse(spec)
903
904    # Example netloc: Administrator:password@HOST:8091
905    #ParseResult tuple(scheme, netloc, path, params, query, fragment)
906    netloc = p[1]
907
908    if not netloc: # When urlparse() can't parse non-http URI's.
909        netloc = spec.split('://')[-1].split('/')[0]
910
911    pair = netloc.split('@') # [ "user:pwsd", "host:port" ].
912    host = (pair[-1] + ":" + str(port)).split(':')[0]
913    port = (pair[-1] + ":" + str(port)).split(':')[1]
914    try:
915        int(port)
916    except ValueError:
917        logging.warn("\"" + port + "\" is not a valid port number, using default port 8091")
918        port = 8091
919
920    username = get_username(opts.username)
921    password = get_password(opts.password)
922    if len(pair) > 1:
923        username = username or (pair[0] + ':').split(':')[0]
924        password = password or (pair[0] + ':').split(':')[1]
925
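# Example of parse_spec(), for illustration only (assumes opts.username and
# opts.password are empty and CB_REST_USERNAME/CB_REST_PASSWORD are unset):
#   parse_spec(opts, "http://Administrator:pass@10.1.2.3:8091/pools", 8091)
#   => ("10.1.2.3", "8091", "Administrator", "pass", "/pools")
# Note the port is returned as a string; int() above is only used to validate it.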
926    return host, port, username, password, p[2]
927
928def rest_request(host, port, user, pswd, ssl, path, method='GET', body='', reason='', headers=None):
929    if reason:
930        reason = "; reason: %s" % (reason)
931    logging.debug("rest_request: %s@%s:%s%s%s" % (user, host, port, path, reason))
932    if ssl:
933        if port not in [couchbaseConstants.SSL_REST_PORT, couchbaseConstants.SSL_QUERY_PORT]:
934            return ("error: invalid port %s used when ssl option is specified") % port, None, None
935        conn = httplib.HTTPSConnection(host, port)
936    else:
937        conn = httplib.HTTPConnection(host, port)
938    try:
939        header = rest_headers(user, pswd, headers)
940        conn.request(method, path, body, header)
941        resp = conn.getresponse()
942    except Exception, e:
943        return ("error: could not access REST API: %s:%s%s" +
944                "; please check source URL, server status, username (-u) and password (-p)" +
945                "; exception: %s%s") % \
946                (host, port, path, e, reason), None, None
947
948    if resp.status in [200, 201, 202, 204, 302]:
949        return None, conn, resp.read()
950
951    conn.close()
952    if resp.status == 401:
953        return ("error: unable to access REST API: %s:%s%s" +
954                "; please check source URL, server status, username (-u) and password (-p)%s") % \
955                (host, port, path, reason), None, None
956
957    return ("error: unable to access REST API: %s:%s%s" +
958            "; please check source URL, server status, username (-u) and password (-p)" +
959            "; response: %s%s") % \
960            (host, port, path, resp.status, reason), None, None
961
962def rest_headers(user, pswd, headers=None):
963    if not headers:
964        headers = {'Content-Type': 'application/json'}
965    if user:
966        auth = 'Basic ' + \
967            string.strip(base64.encodestring(user + ':' + (pswd or '')))
968        headers['Authorization'] = auth
969    return headers
970
971def rest_request_json(host, port, user, pswd, ssl, path, reason=''):
972    err, conn, rest_json = rest_request(host, port, user, pswd, ssl, path,
973                                        reason=reason)
974    if err:
975        return err, None, None
976    if conn:
977        conn.close()
978    try:
979        return None, rest_json, json.loads(rest_json)
980    except ValueError, e:
981        return ("error: could not decode JSON from REST API: %s:%s%s" +
982                "; exception: %s" +
983                "; please check URL, username (-u) and password (-p)") % \
984                (host, port, path, e), None, None
985
986def rest_couchbase(opts, spec):
987    spec = spec.replace('couchbase://', 'http://')
988
989    spec_parts = parse_spec(opts, spec, 8091)
990    host, port, user, pswd, path = spec_parts
991
992    if not path or path == '/':
993        path = '/pools/default/buckets'
994
995    if int(port) in [11209, 11210, 11211]:
996        return "error: invalid port number %s, which is reserved for moxi service" % port, None
997
998    err, rest_json, rest_data = \
999        rest_request_json(host, int(port), user, pswd, opts.ssl, path)
1000    if err:
1001        return err, None
1002
1003    if type(rest_data) == type([]):
1004        rest_buckets = rest_data
1005    else:
1006        rest_buckets = [ rest_data ] # Wrap single bucket in a list.
1007
1008    buckets = []
1009
1010    for bucket in rest_buckets:
1011        if not (bucket and
1012                bucket.get('name', None) and
1013                bucket.get('bucketType', None) and
1014                bucket.get('nodes', None) and
1015                bucket.get('nodeLocator', None)):
1016            return "error: unexpected JSON value from: " + spec + \
1017                " - did you provide the right source URL?", None
1018
1019        if bucket['nodeLocator'] == 'vbucket' and \
1020                (bucket['bucketType'] == 'membase' or
1021                 bucket['bucketType'] == 'couchbase'):
1022            buckets.append(bucket)
1023        else:
1024            logging.warn("skipping bucket that is not a couchbase-bucket: " +
1025                         bucket['name'])
1026
1027    if user is None or pswd is None:
1028        # Check if we have buckets other than the default one
1029        if len(rest_buckets) > 0:
1030            if len(rest_buckets) > 1 or rest_buckets[0].get('name', None) != "default":
1031                return "error: REST username (-u) and password (-p) are required " + \
1032                       "to access all bucket(s)", None
1033    return 0, {'spec': spec, 'buckets': buckets, 'spec_parts': spec_parts}
1034
1035def filter_server(opts, spec, filtor):
1036    spec = spec.replace('couchbase://', 'http://')
1037
1038    spec_parts = parse_spec(opts, spec, 8091)
1039    host, port, user, pswd, path = spec_parts
1040
1041    path = '/pools/default'
1042    err, rest_json, rest_data = \
1043        rest_request_json(host, int(port), user, pswd, opts.ssl, path)
1044    if err:
1045        return err, None
1046
1047    for node in rest_data["nodes"]:
1048        if filtor in node["services"] and node["status"] == "healthy":
1049            return 0, node["hostname"]
1050    return 0, None
1051
1052def filter_bucket_nodes(bucket, spec_parts):
1053    host, port = spec_parts[:2]
1054    if host in ['localhost', '127.0.0.1']:
1055        host = get_ip()
1056    host_port = host + ':' + str(port)
1057    return filter(lambda n: n.get('hostname') == host_port,
1058                  bucket['nodes'])
1059
1060def get_ip():
1061    ip = None
1062    for fname in ['/opt/couchbase/var/lib/couchbase/ip_start',
1063                  '/opt/couchbase/var/lib/couchbase/ip',
1064                  '../var/lib/couchbase/ip_start',
1065                  '../var/lib/couchbase/ip']:
1066        try:
1067            f = open(fname, 'r')
1068            ip = string.strip(f.read())
1069            f.close()
1070            if ip and len(ip):
1071                if '@' in ip:  # e.g. "ns_1@10.1.2.3": keep only the address part
1072                    ip = ip.split('@')[1]
1073                break
1074        except:
1075            pass
1076    if not ip or not len(ip):
1077        ip = '127.0.0.1'
1078    return ip
1079
1080def find_source_bucket_name(opts, source_map):
1081    """If the caller didn't specify a bucket_source and
1082       there's only one bucket in the source_map, use that."""
1083    source_bucket = getattr(opts, "bucket_source", None)
1084    if (not source_bucket and
1085        source_map and
1086        source_map['buckets'] and
1087        len(source_map['buckets']) == 1):
1088        source_bucket = source_map['buckets'][0]['name']
1089    if not source_bucket:
1090        return "error: please specify a bucket_source", None
1091    logging.debug("source_bucket: " + source_bucket)
1092    return 0, source_bucket
1093
1094def find_sink_bucket_name(opts, source_bucket):
1095    """Default bucket_destination to the same as bucket_source."""
1096    sink_bucket = getattr(opts, "bucket_destination", None) or source_bucket
1097    if not sink_bucket:
1098        return "error: please specify a bucket_destination", None
1099    logging.debug("sink_bucket: " + sink_bucket)
1100    return 0, sink_bucket
1101
1102def mkdirs(targetpath):
1103    upperdirs = os.path.dirname(targetpath)
1104    if upperdirs and not os.path.exists(upperdirs):
1105        try:
1106            os.makedirs(upperdirs)
1107        except:
1108            return "Cannot create upper directories for file:%s" % targetpath
1109    return 0
1110
1111def hostport(hoststring, default_port=8091):
1112    try:
1113        host, port = hoststring.split(":")
1114        port = int(port)
1115    except ValueError:
1116        host = hoststring
1117        port = default_port
1118    return host, port
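# Examples of hostport(), for illustration:
#   hostport("10.1.2.3:9000") => ("10.1.2.3", 9000)
#   hostport("10.1.2.3")      => ("10.1.2.3", 8091)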
1119