xref: /4.6.0/couchbase-cli/cluster_manager.py (revision 3a439edd)
1"""Management API's for Couchbase Cluster"""
2
3import requests
4import csv
5import StringIO
6
# Service identifiers accepted by ClusterManager.get_hostnames_for_service.
N1QL_SERVICE, INDEX_SERVICE, MGMT_SERVICE, FTS_SERVICE = 'n1ql', 'index', 'mgmt', 'fts'

# Per-request timeout in seconds, forwarded to every `requests` call.
DEFAULT_REQUEST_TIMEOUT = 60
13
# Every HTTP helper in this module calls requests with verify=False, so
# server certificates are never checked. Silence urllib3's warnings about
# that. NOTE(review): called without a category, this presumably suppresses
# all urllib3 HTTP warnings, not only the insecure-request one - confirm.
# Remove this once we can verify SSL certificates.
requests.packages.urllib3.disable_warnings()
16
def request(f):
    """Wrap a low-level HTTP helper so transport failures become error tuples.

    The wrapped callable is invoked as ``f(cluster_manager, url, ...)`` and is
    expected to return a ``(result, errors)`` tuple. Connection errors and
    read timeouts raised by ``requests`` are turned into ``(None, [message])``
    rather than propagating, matching the error convention of
    ``_handle_response``. ``functools.wraps`` keeps the wrapped method's name
    and docstring intact.
    """
    import functools

    @functools.wraps(f)
    def g(*args, **kwargs):
        cm = args[0]
        # url is normally the second positional argument; also accept it by
        # keyword so building the timeout message can never raise IndexError.
        url = kwargs['url'] if 'url' in kwargs else (args[1] if len(args) > 1 else None)
        try:
            return f(*args, **kwargs)
        except requests.exceptions.ConnectionError:
            return None, ['Unable to connect to host at %s' % cm.hostname]
        except requests.exceptions.ReadTimeout:
            return None, ['Request to host `%s` timed out after %d seconds' % (url, cm.timeout)]
    return g
28
29
class ServiceNotAvailableException(Exception):
    """Signals that the target cluster does not run a requested service."""

    def __init__(self, service):
        message = "Service %s not available in target cluster" % service
        super(ServiceNotAvailableException, self).__init__(message)
35
class ClusterManager(object):
    """A set of REST APIs for managing a Couchbase cluster.

    Methods generally return a ``(result, errors)`` tuple in which ``errors``
    is None on success and a list of human-readable messages on failure."""

    def __init__(self, host, port, username, password, ssl=False, timeout=DEFAULT_REQUEST_TIMEOUT):
        """Create a manager that talks to the node at host:port.

        ``hostname`` becomes the base URL every request is built from:
        "http://host:port", or "https://host:port" when ssl is True.
        ``timeout`` is forwarded to every ``requests`` call."""
        scheme = 'https' if ssl else 'http'
        self.hostname = '%s://%s:%s' % (scheme, host, str(port))

        self.username = username
        self.password = password
        self.timeout = timeout
        self.ssl = ssl

    def n1ql_query(self, stmt, args=None):
        """Send a N1QL query

        Posts the statement to /query/service on the first node that runs the
        n1ql service and returns ``(result, errors)``. Raises
        ServiceNotAvailableException if no node in the target cluster is
        running the n1ql service."""
        hosts, errors = self.get_hostnames_for_service(N1QL_SERVICE)
        if errors:
            return None, errors

        if not hosts:
            raise ServiceNotAvailableException(N1QL_SERVICE)

        url = hosts[0] + '/query/service'
        body = {'statement': str(stmt)}

        if args:
            # NOTE(review): str() of a Python list is not JSON (it uses single
            # quotes). If the query service expects a JSON array here, callers
            # must pass a pre-serialised string - confirm what they pass.
            body['args'] = str(args)

        result, errors = self._post_form_encoded(url, body)
        if errors:
            return None, errors

        return result, None

    def get_hostnames_for_service(self, service_name):
        """ Gets all hostnames that run a service

        Gets all hostnames for specified service and returns a list of strings
        in the form "http://hostname:port". If the ClusterManager is configured
        to use SSL/TLS then "https://" is prefixed to each name instead of
        "http://", and the SSL ports are used for the mgmt and n1ql services.
        An unrecognised service_name yields an empty list."""
        url = self.hostname + '/pools/default/nodeServices'
        data, errors = self._get(url)
        if errors:
            return None, errors

        if self.ssl:
            http_prefix = 'https://'
            # There is no SSL port for the index or fts services
            port_names = {MGMT_SERVICE: 'mgmtSSL',
                          N1QL_SERVICE: 'n1qlSSL',
                          INDEX_SERVICE: 'indexHttp',
                          FTS_SERVICE: 'fts'}
        else:
            http_prefix = 'http://'
            port_names = {MGMT_SERVICE: 'mgmt',
                          N1QL_SERVICE: 'n1ql',
                          INDEX_SERVICE: 'indexHttp',
                          FTS_SERVICE: 'fts'}

        port_name = port_names.get(service_name)
        if port_name is None:
            return [], None

        # A node entry without 'hostname' presumably describes the node we
        # queried (e.g. a single-node cluster). Fall back to the host this
        # manager talks to instead of 127.0.0.1, which would point at the
        # local machine when the cluster is remote.
        default_host = self.hostname.split('://', 1)[-1].rsplit(':', 1)[0]

        hosts = []
        for node in data['nodesExt']:
            ports = node.get('services', {})
            if port_name in ports:
                node_host = node.get('hostname', default_host)
                hosts.append(http_prefix + node_host + ':' + str(ports[port_name]))

        return hosts, None

    def pools(self):
        """ Retrieves information about Couchbase management pools

        Returns Couchbase pools data"""
        url = self.hostname + '/pools'
        return self._get(url)

    def get_server_groups(self):
        """Retrieve the cluster's server groups (GET /pools/default/serverGroups)."""
        url = self.hostname + '/pools/default/serverGroups'
        return self._get(url)

    def get_server_group(self, groupName):
        """Return the server group called groupName.

        When groupName is empty or None, the first group reported by the
        cluster is returned. Returns ``(group, None)`` or ``(None, errors)``."""
        groups, errors = self.get_server_groups()
        if errors:
            return None, errors

        group_list = groups.get("groups") if isinstance(groups, dict) else None
        if not group_list:
            return None, ["No server groups found"]

        if not groupName:
            return group_list[0], None

        for group in group_list:
            if group["name"] == groupName:
                return group, None
        return None, ["Group `%s` not found" % groupName]

    def add_server(self, add_server, groupName, username, password, services):
        """Add a node to a server group.

        Posts the new node's address, credentials and services to the
        ``addNodeURI`` of the group selected by get_server_group(groupName)."""
        group, errors = self.get_server_group(groupName)
        if errors:
            return None, errors

        url = self.hostname + group["addNodeURI"]
        params = {"hostname": add_server,
                  "user": username,
                  "password": password,
                  "services": services}

        return self._post_form_encoded(url, params)

    def create_bucket(self, bucket, ramQuotaMB, authType, saslPassword,
                      replicaNumber, proxyPort, bucketType):
        """Create a bucket via POST /pools/default/buckets.

        authType must be 'none' (proxyPort is sent) or 'sasl' (proxyPort is
        sent as 0 and saslPassword is included). Any other authType is
        rejected without contacting the server."""
        url = self.hostname + '/pools/default/buckets'

        params = {"name": bucket,
                  "ramQuotaMB": ramQuotaMB,
                  "authType": authType,
                  "replicaNumber": replicaNumber,
                  "bucketType": bucketType}

        if authType == 'none':
            params["proxyPort"] = proxyPort
        elif authType == 'sasl':
            params["proxyPort"] = 0
            # Previously the password was accepted but never sent. A None value
            # is omitted from the form body by requests.
            params["saslPassword"] = saslPassword
        else:
            return None, ["Unknown authType `%s`, expected `none` or `sasl`" % authType]

        return self._post_form_encoded(url, params)

    def list_buckets(self):
        """Return ``(names, None)`` with the name of every bucket, or ``(None, errors)``."""
        url = self.hostname + '/pools/default/buckets'
        result, errors = self._get(url)
        if errors:
            return None, errors

        return [bucket["name"] for bucket in result], None

    def set_index_settings(self, storageMode):
        """ Sets global index settings"""
        url = self.hostname + '/settings/indexes'
        return self._post_form_encoded(url, {"storageMode": storageMode})

    def index_settings(self):
        """ Retrieves the index settings

            Returns a map of all global index settings"""
        url = self.hostname + '/settings/indexes'
        return self._get(url)

    def rotate_master_pwd(self):
        """POST /node/controller/rotateDataKey with an empty body."""
        url = self.hostname + '/node/controller/rotateDataKey'
        return self._post_form_encoded(url, None)

    def set_master_pwd(self, password):
        """POST the new master password to /node/controller/changeMasterPassword."""
        url = self.hostname + '/node/controller/changeMasterPassword'
        params = {"newPassword": password}
        return self._post_form_encoded(url, params)

    def setRoles(self, userList, roleList, userNameList):
        """Assign roles to each user in a comma-separated list.

        userList (and userNameList, when given) are comma-separated strings
        parsed with the csv module; both lists must have the same length.
        roleList is sent unchanged as the ``roles`` form field. One PUT to
        /settings/rbac/users/<id> is issued per user and the first failure
        is returned immediately. An empty userList performs no requests and
        returns ``(None, None)``."""
        # we take a comma-delimited list of roles that needs to go into a dictionary
        paramDict = {"roles": roleList}

        userIds = []
        for idList in csv.reader(StringIO.StringIO(userList), delimiter=','):
            userIds.extend(idList)

        # did they specify user names?
        userNames = []
        if userNameList is not None:
            for nameList in csv.reader(StringIO.StringIO(userNameList), delimiter=','):
                userNames.extend(nameList)
            if len(userNames) != len(userIds):
                return None, ["Error: specified %d user ids and %d user names, must have the same number of each." % (len(userIds), len(userNames))]

        # Initialised so an empty user list returns cleanly instead of raising
        # UnboundLocalError at the final return.
        data, errors = None, None
        # we need a separate REST call for each user in the comma-delimited user list
        for index, user in enumerate(userIds):
            paramDict["id"] = user
            if userNames:
                paramDict["name"] = userNames[index]
            url = self.hostname + '/settings/rbac/users/' + user
            data, errors = self._put(url, paramDict)
            if errors:
                return data, errors

        return data, errors

    def deleteRoles(self, userList):
        """Delete the RBAC entry of every user in a comma-separated list.

        One DELETE to /settings/rbac/users/<id> is issued per user id and the
        first failure is returned immediately. An empty userList performs no
        requests and returns ``(None, None)``."""
        # Initialised so an empty user list returns cleanly instead of raising
        # UnboundLocalError at the final return.
        data, errors = None, None
        # need a separate REST call for each user in the comma-delimited user list
        for users in csv.reader(StringIO.StringIO(userList), delimiter=','):
            for user in users:
                url = self.hostname + '/settings/rbac/users/' + user
                data, errors = self._delete(url)
                if errors:
                    return data, errors

        return data, errors

    def getRoles(self):
        """Retrieve all RBAC user entries (GET /settings/rbac/users)."""
        url = self.hostname + '/settings/rbac/users'
        return self._get(url)

    def myRoles(self):
        """Retrieve the authenticated user's own entry (GET /whoami)."""
        url = self.hostname + '/whoami'
        return self._get(url)

    def retrieve_cluster_certificate(self, extended=False):
        """ Retrieves the current cluster certificate

        Gets the current cluster certificate. If extended is set to True then
        we return the extended certificate which contains the certificate type,
        certificate key, expiration, subject, and warnings."""
        url = self.hostname + '/pools/default/certificate'
        if extended:
            url += '?extended=true'
        return self._get(url)

    def regenerate_cluster_certificate(self):
        """ Regenerates the cluster certificate

        Regenerates the cluster certificate and returns the new certificate."""
        url = self.hostname + '/controller/regenerateCertificate'
        return self._post_form_encoded(url, None)

    def upload_cluster_certificate(self, certificate):
        """ Uploads a new cluster certificate"""
        url = self.hostname + '/controller/uploadClusterCA'
        return self._post_form_encoded(url, certificate)

    def retrieve_node_certificate(self, node):
        """ Retrieves the current node certificate

        Returns the current node certificate"""
        url = self.hostname + '/pools/default/certificate/node/' + node
        return self._get(url)

    def set_node_certificate(self):
        """Activates the current node certificate

        Grabs chain.pem and pkey.pem from the <data folder>/inbox/ directory and
        applies them to the node. chain.pem contains the chain encoded certificates
        starting from the node certificate and ending with the last intermediate
        certificate before cluster CA. pkey.pem contains the pem encoded private
        key for node certificates. Both files should exist on the server before
        this API is called."""
        url = self.hostname + '/node/controller/reloadCertificate'
        return self._post_form_encoded(url, None)

    # Low level methods for basic HTTP operations. Each is wrapped by @request,
    # so connection errors and read timeouts come back as (None, [message]).
    # verify=False: server certificates are not checked yet.

    @request
    def _get(self, url):
        response = requests.get(url, auth=(self.username, self.password), verify=False,
                                timeout=self.timeout)
        return _handle_response(response)

    @request
    def _post_form_encoded(self, url, params):
        response = requests.post(url, auth=(self.username, self.password), data=params,
                                 verify=False, timeout=self.timeout)
        return _handle_response(response)

    @request
    def _put(self, url, params):
        response = requests.put(url, data=params, auth=(self.username, self.password),
                                verify=False, timeout=self.timeout)
        return _handle_response(response)

    @request
    def _delete(self, url):
        response = requests.delete(url, auth=(self.username, self.password),
                                   verify=False, timeout=self.timeout)
        return _handle_response(response)
342
343
def _handle_response(response):
    """Convert a ``requests`` response into a ``(result, errors)`` tuple.

    * 200/202: ``""`` when there is no Content-Type header, the decoded JSON
      body when the Content-Type mentions ``application/json``, otherwise the
      raw response text.
    * 400/404: the decoded JSON body when it is a list, otherwise the raw
      response text wrapped in a single-element list.
    * 401, 500 and any other status: a fixed human-readable error message.
    """
    content_type = response.headers.get('Content-Type')
    status = response.status_code

    if status in (200, 202):
        if content_type is None:
            return "", None
        if 'application/json' in content_type:
            return response.json(), None
        return response.text, None

    if status in (400, 404):
        # Error responses are not guaranteed to carry a Content-Type header;
        # indexing it directly raised KeyError here.
        if content_type is not None and 'application/json' in content_type:
            errors = response.json()
            if isinstance(errors, list):
                return None, errors
        return None, [response.text]

    if status == 401:
        return None, ['ERROR: unable to access the REST API - please check your username ' +
                      '(-u) and password (-p)']

    if status == 500:
        return None, ['ERROR: Internal server error, please retry your request']

    return None, ['Error: Received unexpected status %d' % status]
365