1#!/usr/bin/env python
2
3import glob
4import logging
5import os
6import sys
7import socket
8import couchbaseConstants
9from cbcollections import defaultdict
10
11from pump import EndPoint, Source, Batch
# ctypes is used to reinterpret stored row flags as unsigned 32-bit values.
# If the plain import fails, retry once with the bundled Couchbase python lib
# directory removed from sys.path, then put that directory back.
# NOTE(review): presumably the bundled directory can shadow a working system
# ctypes -- confirm against the packaging that installs it.
try:
    import ctypes
except ImportError:
    cb_path = '/opt/couchbase/lib/python'
    # Remove every occurrence of the bundled path before retrying.
    while cb_path in sys.path:
        sys.path.remove(cb_path)
    try:
        import ctypes
    except ImportError:
        sys.exit('error: could not import ctypes module')
    else:
        # ctypes loaded; restore the bundled path at the front of sys.path.
        sys.path.insert(0, cb_path)
24
MIN_SQLITE_VERSION = '3.3'


def sqlite_version_tuple(version):
    """Return a dotted version string such as '3.10.2' as a tuple of ints.

    Tuples of ints order numerically, whereas the raw strings do not
    ('3.10.2' < '3.3' when compared as strings). Any non-digit characters
    that trail the digits of a component are ignored; a component with no
    leading digits counts as 0.
    """
    parts = []
    for piece in str(version).split('.'):
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or '0'))
    return tuple(parts)


# Candidate import statements, tried in order. The first one that imports a
# module linked against at least MIN_SQLITE_VERSION wins and binds the
# module-level name `sqlite3`.
import_stmts = (
    'from pysqlite2 import dbapi2 as sqlite3',
    'import sqlite3',
)
for stmt in import_stmts:
    try:
        exec(stmt)  # Parenthesized form is valid in both Python 2 and 3.
    except ImportError:
        continue
    if (sqlite_version_tuple(sqlite3.sqlite_version) >=
            sqlite_version_tuple(MIN_SQLITE_VERSION)):
        break
else:
    sys.exit("Error: could not import required version of sqlite3 module")

MBF_VERSION = 2 # sqlite pragma user version for Couchbase 1.8.
43
class MBFSource(Source):
    """Can read 1.8 server master and *.mb data files.

    The spec names the master sqlite file; companion data files are found by
    globbing '<spec>-*.mb'. All files are ATTACHed to one in-memory sqlite
    connection and every selected kv_<vbucket_id> table is streamed in
    batches.
    """

    def __init__(self, opts, spec, source_bucket, source_node,
                 source_map, sink_map, ctl, cur):
        super(MBFSource, self).__init__(opts, spec, source_bucket, source_node,
                                        source_map, sink_map, ctl, cur)
        # While reading: (db, db_kv_names, cursor, vbucket_states); else None.
        self.cursor_todo = None
        # True once every table has been read or a fatal error occurred.
        self.cursor_done = False

        # Per-table row query, formatted with (attached db name, table name).
        self.s = """SELECT vbucket, k, flags, exptime, cas, v, vb_version
                      FROM `%s`.`%s`"""

    @staticmethod
    def can_handle(opts, spec):
        """Return True if spec is a file whose user_version is MBF_VERSION."""
        return os.path.isfile(spec) and MBFSource.version(spec) == MBF_VERSION

    @staticmethod
    def check_base(opts, spec):
        # Skip immediate superclass Source.check_base(),
        # since MBFSource can handle different vbucket states.
        return EndPoint.check_base(opts, spec)

    @staticmethod
    def check(opts, spec):
        """Validate spec and describe it as a single bucket on a single node.

        Returns (0, info) on success. info['buckets'][0]['nodes'][0]
        ['vbucket_states'] maps a state string (e.g. 'active') to
        {vbucket_id: {'vbucket_id', 'vb_version', 'state', 'checkpoint_id'}};
        it is left empty when the master file has no vbucket_states table.
        Returns (error_string, None) if spec is not a file or no db file has
        a user_version of at least MBF_VERSION.
        """
        spec = os.path.normpath(spec)
        if not os.path.isfile(spec):
            return "error: backup_dir is not a file: " + spec, None

        db_files = MBFSource.db_files(spec)
        versions = MBFSource.db_file_versions(db_files)
        logging.debug(" MBFSource check db file versions: %s" % (versions))
        if max(versions.values()) < MBF_VERSION:
            err = ("error: wrong backup/db file versions;\n" +
                   " either the metadata db file is not specified\n" +
                   " or the backup files need upgrading to version %s;\n" +
                   " please use cbdbupgrade or contact support.") \
                   % (MBF_VERSION)
            return err, None

        # Map of state string (e.g., 'active') to map of vbucket_id to info.
        vbucket_states = defaultdict(dict)
        sql = """SELECT vbid, vb_version, state, checkpoint_id
                   FROM vbucket_states"""
        try:
            db = sqlite3.connect(spec)
            try:
                cur = db.cursor()
                for row in cur.execute(sql):
                    vbucket_id = row[0]
                    state = str(row[2])
                    vbucket_states[state][vbucket_id] = {
                        'vbucket_id': vbucket_id,
                        'vb_version': row[1],
                        'state': state,
                        'checkpoint_id': row[3]
                        }
                cur.close()
            finally:
                db.close()  # Close even when the query fails.
        except sqlite3.DatabaseError:
            pass # A missing vbucket_states table is expected.

        return 0, {'spec': spec,
                   'buckets':
                       [{'name': os.path.basename(spec),
                         'nodes': [{'hostname': 'N/A',
                                    'vbucket_states': vbucket_states
                                    }]}]}

    @staticmethod
    def db_file_versions(db_files):
        """Return {db_file: sqlite user_version} for each given file."""
        rv = {}
        for db_file in db_files:
            rv[db_file] = MBFSource.version(db_file)
        return rv

    @staticmethod
    def version(db_file):
        """Return db_file's PRAGMA user_version, or 0 if it can't be read."""
        try:
            return int(MBFSource.run_sql(db_file, "PRAGMA user_version;")[0])
        except sqlite3.DatabaseError as e:
            logging.error("error: could not access user_version from: %s" +
                          "; exception: %s" +
                          "; perhaps it is being used by another program" +
                          " like couchbase-server", db_file, e)
            return 0

    @staticmethod
    def db_files(spec):
        """Return the master file followed by its '<spec>-*.mb' companions."""
        return [spec] + glob.glob(spec + "-*.mb")

    @staticmethod
    def run_sql(db_file, sql):
        """Execute sql against db_file and return the first row (or None).

        The cursor and connection are closed even if execution fails.
        """
        db = sqlite3.connect(db_file)
        try:
            cur = db.cursor()
            try:
                cur.execute(sql)
                return cur.fetchone()
            finally:
                cur.close()
        finally:
            db.close()

    @staticmethod
    def provide_design(opts, source_spec, source_bucket, source_map):
        """1.8 data files carry no design documents; always (0, None)."""
        return 0, None

    def provide_batch(self):
        """Return the next batch of mutations read from the kv tables.

        Returns (0, batch) while data remains (a batch can be empty when it
        closes at a table boundary), (0, None) once every table has been
        read, or (error_string, None) on failure. After an exception the
        source is marked done and its in-memory db is closed.
        """
        if self.cursor_done:
            return 0, None

        batch = Batch(self)

        batch_max_size = self.opts.extra['batch_max_size']
        batch_max_bytes = self.opts.extra['batch_max_bytes']

        source_vbucket_state = \
            getattr(self.opts, 'source_vbucket_state', 'active')

        try:
            if self.cursor_todo is None:
                rv = self._start_cursor_todo()
                if rv != 0:
                    return rv, None

            db, db_kv_names, cursor, vbucket_states = self.cursor_todo
            if not db:
                self.cursor_done = True
                self.cursor_todo = None
                return 0, None

            # vbucket_id -> vb_version for vbuckets in the requested state.
            wanted_versions = vbucket_states[source_vbucket_state]

            while (not self.cursor_done and
                   batch.size() < batch_max_size and
                   batch.bytes < batch_max_bytes):
                if not cursor:
                    if not db_kv_names:
                        self.cursor_done = True
                        self.cursor_todo = None
                        db.close()
                        break

                    db_name, kv_name = db_kv_names.pop()
                    vbucket_id = int(kv_name.split('_')[-1])
                    if vbucket_id not in wanted_versions:
                        # Not in the requested state: move on to the next
                        # table rather than closing the batch early.
                        continue

                    logging.debug("  MBFSource db/kv table: %s/%s" %
                                  (db_name, kv_name))
                    cursor = db.cursor()
                    cursor.execute(self.s % (db_name, kv_name))
                    self.cursor_todo = (db, db_kv_names, cursor, vbucket_states)

                row = cursor.fetchone()
                if not row:
                    cursor.close()
                    self.cursor_todo = (db, db_kv_names, None, vbucket_states)
                    break # Close the batch; next pass hits new db_name/kv_name.

                vbucket_id, key, flg, exp, cas, val = row[:6]
                version = int(row[6])

                if self.skip(key, vbucket_id):
                    continue

                # Keep only rows whose vb_version matches the state table;
                # .get() skips rows of vbuckets outside the requested state
                # instead of aborting the transfer with a KeyError.
                if version != wanted_versions.get(vbucket_id):
                    continue

                meta = ''
                flg = socket.ntohl(ctypes.c_uint32(flg).value)
                batch.append((couchbaseConstants.CMD_TAP_MUTATION,
                              vbucket_id, key, flg, exp, cas, meta, val,
                              0, 0, 0), len(val))

        except Exception as e:
            self._close_cursor_todo()
            return "error: MBFSource exception: " + str(e), None

        return 0, batch

    def _start_cursor_todo(self):
        """Connect, choose the kv tables to read and seed self.cursor_todo.

        Returns 0 on success or an error string; the in-memory db is closed
        on every failure path.
        """
        rv, db, attached_dbs, table_dbs, vbucket_states = self.connect_db()
        if rv != 0:
            return rv

        # Exactly one attached db must contain the vbucket_states table.
        if len(table_dbs.get(u'vbucket_states', ())) != 1:
            db.close()
            return "error: no unique vbucket_states table"

        try:
            db_kv_names = self._db_kv_names(table_dbs)
        except Exception:
            db.close()
            raise

        self.cursor_todo = (db, db_kv_names, None, vbucket_states)
        return 0

    def _db_kv_names(self, table_dbs):
        """Return the (db_name, kv_name) pairs to read.

        Every 'kv_*' table is chosen when opts.id is None, otherwise only
        'kv_<opts.id>'. Pairs are sorted ascending by the table's numeric
        suffix and then by db name; provide_batch consumes them with pop().
        """
        if self.opts.id is None:
            kv_names = [name for name in table_dbs if name.startswith('kv_')]
        else:
            only = "kv_%s" % (self.opts.id)
            kv_names = [name for name in table_dbs if name == only]

        db_kv_names = []
        for kv_name in sorted(kv_names, key=lambda x: int(x.split('_')[-1])):
            for db_name in sorted(table_dbs[kv_name]):
                db_kv_names.append((db_name, kv_name))
        return db_kv_names

    def _close_cursor_todo(self):
        """Mark the source finished and close any in-memory db still open."""
        todo = self.cursor_todo
        self.cursor_done = True
        self.cursor_todo = None
        if todo and todo[0]:
            try:
                todo[0].close()
            except sqlite3.Error:
                pass # Already failing; the original error is what we report.

    @staticmethod
    def total_msgs(opts, source_bucket, source_node, source_map):
        """Return (0, item_count) from stats_snap, or (0, None) if unknown.

        Only 'active' and 'replica' source_vbucket_state values are looked
        up. Any failure is treated as an unknown (None) count.
        """
        total = None

        vb_state = getattr(opts, "source_vbucket_state", None)
        if vb_state not in ["active", "replica"]:
            return 0, total

        try:
            db = sqlite3.connect(source_map['spec'])
            try:
                cursor = db.cursor()
                cursor.execute("SELECT value FROM stats_snap"
                               " WHERE name LIKE ?",
                               ("vb_%s_curr_items" % vb_state,))
                row = cursor.fetchone()
                cursor.close()
            finally:
                db.close()
            if row:
                # Either we can find the stats in the first row, or we don't.
                total = int(str(row[0]))
        except Exception as e:
            # Best effort: report an unknown total rather than failing.
            logging.debug("  MBFSource total_msgs unavailable: %s" % e)

        return 0, total

    def connect_db(self):
        """Open every db file of self.spec attached to one in-memory db.

        Returns (0, db, attached_dbs, table_dbs, vbucket_states) where
        attached_dbs lists the ATTACH aliases ('db0', 'db1', ...), table_dbs
        maps table name -> [aliases containing it] and vbucket_states maps
        state -> {vbucket_id: vb_version}. On failure it returns
        (error_string, None, None, None, None), so callers unpacking five
        values see the error instead of a ValueError.
        """
        # Build vbucket state hash table.
        vbucket_states = defaultdict(dict)
        sql = """SELECT vbid, vb_version, state FROM vbucket_states"""
        try:
            state_db = sqlite3.connect(self.spec)
            try:
                cur = state_db.cursor()
                for row in cur.execute(sql):
                    vbucket_id = int(row[0])
                    vb_version = int(row[1])
                    state = str(row[2])
                    vbucket_states[state][vbucket_id] = vb_version
                cur.close()
            finally:
                state_db.close()
        except sqlite3.DatabaseError:
            return ("error: no vbucket_states table was found;" +
                    " check if db files are correct"), None, None, None, None

        db = sqlite3.connect(':memory:')
        logging.debug("  MBFSource connect_db: %s" % self.spec)

        try:
            db_files = MBFSource.db_files(self.spec)
            logging.debug("  MBFSource db_files: %s" % db_files)

            attached_dbs = ["db%s" % (i) for i in range(len(db_files))]
            db.executemany("attach ? as ?", zip(db_files, attached_dbs))

            # Find all tables, filling a table_name => [db_name] map.
            table_dbs = {}
            for db_name in attached_dbs:
                cursor = db.cursor()
                cursor.execute("SELECT name FROM %s.sqlite_master"
                               " WHERE type = 'table'" % db_name)
                for (table_name,) in cursor:
                    table_dbs.setdefault(table_name, []).append(db_name)
                cursor.close()
        except Exception:
            db.close()
            raise

        if not any(name.startswith("kv_") for name in table_dbs):
            db.close()
            return ("error: no kv data was found;" +
                    " check if db files are correct"), None, None, None, None

        logging.debug("  MBFSource total # tables: %s" % len(table_dbs))
        return 0, db, attached_dbs, table_dbs, vbucket_states
323