1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2009 Christopher Lenz
5# All rights reserved.
6#
7# This software is licensed as described in the file COPYING, which
8# you should have received as part of this distribution.
9
10"""Simple HTTP client implementation based on the ``httplib`` module in the
11standard library.
12"""
13
14from base64 import b64encode
15from datetime import datetime
16import errno
17from httplib import BadStatusLine, HTTPConnection, HTTPSConnection
18import socket
19import time
20try:
21    from cStringIO import StringIO
22except ImportError:
23    from StringIO import StringIO
24import sys
25try:
26    from threading import Lock
27except ImportError:
28    from dummy_threading import Lock
29import urllib
30from urlparse import urlsplit, urlunsplit
31
32from couchdb import json
33
# Public API of this module.
__all__ = [
    'HTTPError', 'PreconditionFailed', 'ResourceNotFound',
    'ResourceConflict', 'ServerError', 'Unauthorized', 'RedirectLimit',
    'Session', 'Resource',
]
# Docstrings in this module are written in reStructuredText.
__docformat__ = 'restructuredtext en'
38
39
40if sys.version < '2.6':
41
42    class TimeoutMixin:
43        """Helper mixin to add timeout before socket connection"""
44
45        # taken from original python2.5 httplib source code with timeout setting added
46        def connect(self):
47            """Connect to the host and port specified in __init__."""
48            msg = "getaddrinfo returns an empty list"
49            for res in socket.getaddrinfo(self.host, self.port, 0,
50                                          socket.SOCK_STREAM):
51                af, socktype, proto, canonname, sa = res
52                try:
53                    self.sock = socket.socket(af, socktype, proto)
54                    if self.debuglevel > 0:
55                        print "connect: (%s, %s)" % (self.host, self.port)
56
57                    # setting socket timeout
58                    self.sock.settimeout(self.timeout)
59
60                    self.sock.connect(sa)
61                except socket.error, msg:
62                    if self.debuglevel > 0:
63                        print 'connect fail:', (self.host, self.port)
64                    if self.sock:
65                        self.sock.close()
66                    self.sock = None
67                    continue
68                break
69            if not self.sock:
70                raise socket.error, msg
71
72    _HTTPConnection = HTTPConnection
73    _HTTPSConnection = HTTPSConnection
74
75    class HTTPConnection(TimeoutMixin, _HTTPConnection):
76        def __init__(self, *a, **k):
77            timeout = k.pop('timeout', None)
78            _HTTPConnection.__init__(self, *a, **k)
79            self.timeout = timeout
80
81    class HTTPSConnection(TimeoutMixin, _HTTPSConnection):
82        def __init__(self, *a, **k):
83            timeout = k.pop('timeout', None)
84            _HTTPSConnection.__init__(self, *a, **k)
85            self.timeout = timeout
86
87
class HTTPError(Exception):
    """Common base class for errors raised for HTTP status codes of 400 and
    above."""


class PreconditionFailed(HTTPError):
    """Raised when a request is answered with HTTP 412 (Precondition
    Failed).
    """


class ResourceNotFound(HTTPError):
    """Raised when a request is answered with HTTP 404 (Not Found)."""


class ResourceConflict(HTTPError):
    """Raised when a request is answered with HTTP 409 (Conflict)."""


class ServerError(HTTPError):
    """Raised for an error status that has no more specific exception
    class.
    """


class Unauthorized(HTTPError):
    """Raised when the server demands authentication credentials and none,
    or wrong ones, were supplied.
    """


class RedirectLimit(Exception):
    """Raised when a request is redirected more times than the configured
    maximum allows.
    """
126
127
# Block size in bytes used when streaming bodies; responses with a smaller
# known Content-Length are buffered in memory instead of streamed.
CHUNK_SIZE = 8 * 1024
# Response cache bounds: (entries kept after a cleanup, entry count above
# which a cleanup runs).  Arbitrary values that only serve to limit memory.
CACHE_SIZE = (10, 75)
130
def cache_sort(i):
    """Sort key for a ``(url, (status, headers, body))`` cache item.

    Returns the response's ``Date`` header as a naive `datetime`, so items
    sort from oldest to newest.  The date is parsed with
    `email.utils.parsedate`, which (unlike ``time.strptime`` with ``%b``)
    does not depend on the current locale.  Items without a ``Date`` header,
    or with one that cannot be parsed, sort first as `datetime.min` instead
    of raising.
    """
    from email.utils import parsedate
    date = i[1][1].get('Date')
    parsed = parsedate(date) if date else None
    if parsed is None:
        return datetime.min
    try:
        return datetime(*parsed[:6])
    except ValueError:
        # Guard against out-of-range fields in a malformed header.
        return datetime.min
134
class ResponseBody(object):
    """Minimal file-like wrapper around a streamed ``httplib`` response.

    ``callback`` (may be `None`) is invoked at most once, after the response
    body has been fully consumed; `Session` uses it to return the connection
    to its pool.
    """

    def __init__(self, resp, callback):
        self.resp = resp
        self.callback = callback

    def read(self, size=None):
        """Read up to ``size`` bytes, or everything if ``size`` is `None`.

        Reading everything, or getting fewer bytes than requested, is taken
        as the end of the body and releases the response via `close`.
        """
        data = self.resp.read(size)
        if size is None or len(data) < size:
            self.close()
        return data

    def close(self):
        """Drain any unread body and release the connection."""
        while not self.resp.isclosed():
            self.resp.read(CHUNK_SIZE)
        self._release()

    def _release(self):
        # Run the callback at most once.  Both close() and a fully consumed
        # iteration end up here; running it twice would hand the same
        # connection back to the pool twice.
        callback, self.callback = self.callback, None
        if callback:
            callback()

    def __iter__(self):
        """Yield the lines of a chunked response body.

        Lines are split within each chunk, so a line spanning two chunks is
        yielded as two pieces.
        """
        assert self.resp.msg.get('transfer-encoding') == 'chunked'
        while not self.resp.isclosed():
            # Chunk-size line: hex size, optionally followed by ";extensions".
            size_line = self.resp.fp.readline()
            chunksz = int(size_line.split(';', 1)[0].strip(), 16)
            if not chunksz:
                # Last chunk: consume its CRLF (no trailers expected).
                self.resp.fp.read(2)
                self.resp.close()
                self._release()
                break
            chunk = self.resp.fp.read(chunksz)
            for ln in chunk.splitlines():
                yield ln
            self.resp.fp.read(2) # crlf
169
170
# errno codes of socket errors after which Session retries a request by
# default (connection-level failures: refused/reset/aborted connections,
# unreachable or down hosts and networks, broken pipes, timeouts).
RETRYABLE_ERRORS = frozenset((
    errno.ECONNABORTED, errno.ECONNREFUSED, errno.ECONNRESET,
    errno.EHOSTDOWN, errno.EHOSTUNREACH,
    errno.ENETDOWN, errno.ENETRESET, errno.ENETUNREACH,
    errno.EPIPE, errno.ETIMEDOUT,
))
177
178
179class Session(object):
180
181    def __init__(self, cache=None, timeout=None, max_redirects=5,
182                 retry_delays=[0], retryable_errors=RETRYABLE_ERRORS):
183        """Initialize an HTTP client session.
184
185        :param cache: an instance with a dict-like interface or None to allow
186                      Session to create a dict for caching.
187        :param timeout: socket timeout in number of seconds, or `None` for no
188                        timeout (the default)
189        :param retry_delays: list of request retry delays.
190        """
191        from couchdb import __version__ as VERSION
192        self.user_agent = 'CouchDB-Python/%s' % VERSION
193        if cache is None:
194            cache = {}
195        self.cache = cache
196        self.timeout = timeout
197        self.max_redirects = max_redirects
198        self.perm_redirects = {}
199        self.conns = {} # HTTP connections keyed by (scheme, host)
200        self.lock = Lock()
201        self.retry_delays = list(retry_delays) # We don't want this changing on us.
202        self.retryable_errors = set(retryable_errors)
203
204    def request(self, method, url, body=None, headers=None, credentials=None,
205                num_redirects=0):
206        if url in self.perm_redirects:
207            url = self.perm_redirects[url]
208        method = method.upper()
209
210        if headers is None:
211            headers = {}
212        headers.setdefault('Accept', 'application/json')
213        headers['User-Agent'] = self.user_agent
214
215        cached_resp = None
216        if method in ('GET', 'HEAD'):
217            cached_resp = self.cache.get(url)
218            if cached_resp is not None:
219                etag = cached_resp[1].get('etag')
220                if etag:
221                    headers['If-None-Match'] = etag
222
223        if body is None:
224            headers.setdefault('Content-Length', '0')
225        else:
226            if not isinstance(body, basestring):
227                try:
228                    body = json.encode(body).encode('utf-8')
229                except TypeError:
230                    # Check for somethine file-like or re-raise the exception
231                    # to avoid masking real JSON encoding errors.
232                    if not hasattr(body, 'read'):
233                        raise
234                else:
235                    headers.setdefault('Content-Type', 'application/json')
236            if isinstance(body, basestring):
237                headers.setdefault('Content-Length', str(len(body)))
238            else:
239                headers['Transfer-Encoding'] = 'chunked'
240
241        authorization = basic_auth(credentials)
242        if authorization:
243            headers['Authorization'] = authorization
244
245        path_query = urlunsplit(('', '') + urlsplit(url)[2:4] + ('',))
246        conn = self._get_connection(url)
247
248        def _try_request_with_retries(retries):
249            while True:
250                try:
251                    return _try_request()
252                except socket.error, e:
253                    ecode = e.args[0]
254                    if ecode not in self.retryable_errors:
255                        raise
256                    try:
257                        delay = retries.next()
258                    except StopIteration:
259                        # No more retries, raise last socket error.
260                        raise e
261                    time.sleep(delay)
262                    conn.close()
263
264        def _try_request():
265            try:
266                conn.putrequest(method, path_query, skip_accept_encoding=True)
267                for header in headers:
268                    conn.putheader(header, headers[header])
269                conn.endheaders()
270                if body is not None:
271                    if isinstance(body, str):
272                        conn.send(body)
273                    else: # assume a file-like object and send in chunks
274                        while 1:
275                            chunk = body.read(CHUNK_SIZE)
276                            if not chunk:
277                                break
278                            conn.send(('%x\r\n' % len(chunk)) + chunk + '\r\n')
279                        conn.send('0\r\n\r\n')
280                return conn.getresponse()
281            except BadStatusLine, e:
282                # httplib raises a BadStatusLine when it cannot read the status
283                # line saying, "Presumably, the server closed the connection
284                # before sending a valid response."
285                # Raise as ECONNRESET to simplify retry logic.
286                if e.line == '' or e.line == "''":
287                    raise socket.error(errno.ECONNRESET)
288                else:
289                    raise
290
291        resp = _try_request_with_retries(iter(self.retry_delays))
292        status = resp.status
293
294        # Handle conditional response
295        if status == 304 and method in ('GET', 'HEAD'):
296            resp.read()
297            self._return_connection(url, conn)
298            status, msg, data = cached_resp
299            if data is not None:
300                data = StringIO(data)
301            return status, msg, data
302        elif cached_resp:
303            del self.cache[url]
304
305        # Handle redirects
306        if status == 303 or \
307                method in ('GET', 'HEAD') and status in (301, 302, 307):
308            resp.read()
309            self._return_connection(url, conn)
310            if num_redirects > self.max_redirects:
311                raise RedirectLimit('Redirection limit exceeded')
312            location = resp.getheader('location')
313            if status == 301:
314                self.perm_redirects[url] = location
315            elif status == 303:
316                method = 'GET'
317            return self.request(method, location, body, headers,
318                                num_redirects=num_redirects + 1)
319
320        data = None
321        streamed = False
322
323        # Read the full response for empty responses so that the connection is
324        # in good state for the next request
325        if method == 'HEAD' or resp.getheader('content-length') == '0' or \
326                status < 200 or status in (204, 304):
327            resp.read()
328            self._return_connection(url, conn)
329
330        # Buffer small non-JSON response bodies
331        elif int(resp.getheader('content-length', sys.maxint)) < CHUNK_SIZE:
332            data = resp.read()
333            self._return_connection(url, conn)
334
335        # For large or chunked response bodies, do not buffer the full body,
336        # and instead return a minimal file-like object
337        else:
338            data = ResponseBody(resp,
339                                lambda: self._return_connection(url, conn))
340            streamed = True
341
342        # Handle errors
343        if status >= 400:
344            ctype = resp.getheader('content-type')
345            if data is not None and 'application/json' in ctype:
346                data = json.decode(data)
347                error = data.get('error'), data.get('reason')
348            elif method != 'HEAD':
349                error = resp.read()
350                self._return_connection(url, conn)
351            else:
352                error = ''
353            if status == 401:
354                raise Unauthorized(error)
355            elif status == 404:
356                raise ResourceNotFound(error)
357            elif status == 409:
358                raise ResourceConflict(error)
359            elif status == 412:
360                raise PreconditionFailed(error)
361            else:
362                raise ServerError((status, error))
363
364        # Store cachable responses
365        if not streamed and method == 'GET' and 'etag' in resp.msg:
366            self.cache[url] = (status, resp.msg, data)
367            if len(self.cache) > CACHE_SIZE[1]:
368                self._clean_cache()
369
370        if not streamed and data is not None:
371            data = StringIO(data)
372
373        return status, resp.msg, data
374
375    def _clean_cache(self):
376        ls = sorted(self.cache.iteritems(), key=cache_sort)
377        self.cache = dict(ls[-CACHE_SIZE[0]:])
378
379    def _get_connection(self, url):
380        scheme, host = urlsplit(url, 'http', False)[:2]
381        self.lock.acquire()
382        try:
383            conns = self.conns.setdefault((scheme, host), [])
384            if conns:
385                conn = conns.pop(-1)
386            else:
387                if scheme == 'http':
388                    cls = HTTPConnection
389                elif scheme == 'https':
390                    cls = HTTPSConnection
391                else:
392                    raise ValueError('%s is not a supported scheme' % scheme)
393                conn = cls(host, timeout=self.timeout)
394                conn.connect()
395        finally:
396            self.lock.release()
397        return conn
398
399    def _return_connection(self, url, conn):
400        scheme, host = urlsplit(url, 'http', False)[:2]
401        self.lock.acquire()
402        try:
403            self.conns.setdefault((scheme, host), []).append(conn)
404        finally:
405            self.lock.release()
406
407
class Resource(object):
    """An HTTP resource identified by a URL and accessed through a `Session`.

    A user name and password embedded in the URL are removed from ``url`` and
    passed to the session as basic-auth credentials.  ``headers`` are sent
    with every request and are copied to child resources created by calling
    the instance.
    """

    def __init__(self, url, session, headers=None):
        """
        :param url: resource URL, optionally containing ``user:password@``
        :param session: the `Session` to use, or `None` to create a new one
        :param headers: dict of headers sent with every request
        """
        self.url, self.credentials = extract_credentials(url)
        if session is None:
            session = Session()
        self.session = session
        self.headers = headers or {}

    def __call__(self, *path):
        """Return a new resource for the given path segments below this one,
        sharing the session, credentials and (a copy of) the headers."""
        obj = type(self)(urljoin(self.url, *path), self.session)
        obj.credentials = self.credentials
        obj.headers = self.headers.copy()
        return obj

    def delete(self, path=None, headers=None, **params):
        return self._request('DELETE', path, headers=headers, **params)

    def get(self, path=None, headers=None, **params):
        return self._request('GET', path, headers=headers, **params)

    def head(self, path=None, headers=None, **params):
        return self._request('HEAD', path, headers=headers, **params)

    def post(self, path=None, body=None, headers=None, **params):
        return self._request('POST', path, body=body, headers=headers,
                             **params)

    def put(self, path=None, body=None, headers=None, **params):
        return self._request('PUT', path, body=body, headers=headers, **params)

    def delete_json(self, *a, **k):
        """Like `delete`, but decode a JSON response body."""
        status, headers, data = self.delete(*a, **k)
        return self._decode_json(status, headers, data)

    def get_json(self, *a, **k):
        """Like `get`, but decode a JSON response body."""
        status, headers, data = self.get(*a, **k)
        return self._decode_json(status, headers, data)

    def post_json(self, *a, **k):
        """Like `post`, but decode a JSON response body."""
        status, headers, data = self.post(*a, **k)
        return self._decode_json(status, headers, data)

    def put_json(self, *a, **k):
        """Like `put`, but decode a JSON response body."""
        status, headers, data = self.put(*a, **k)
        return self._decode_json(status, headers, data)

    @staticmethod
    def _decode_json(status, headers, data):
        """Return ``(status, headers, data)`` with ``data`` decoded from JSON
        when the response declares an ``application/json`` content type.

        A missing Content-Type header or a missing body (``data is None``)
        leaves ``data`` unchanged instead of raising.
        """
        ctype = headers.get('content-type') or ''
        if data is not None and 'application/json' in ctype:
            data = json.decode(data.read())
        return status, headers, data

    def _request(self, method, path=None, body=None, headers=None, **params):
        # Per-call headers override the resource-wide ones.
        all_headers = self.headers.copy()
        all_headers.update(headers or {})
        if path is not None:
            url = urljoin(self.url, path, **params)
        else:
            url = urljoin(self.url, **params)
        return self.session.request(method, url, body=body,
                                    headers=all_headers,
                                    credentials=self.credentials)
473
474
def extract_credentials(url):
    """Extract authentication (user name and password) credentials from the
    given URL.

    Returns ``(url_without_credentials, credentials)``, where ``credentials``
    is a tuple of the percent-decoded user name and password, or `None`.  The
    netloc is split at its last ``@`` and the credentials at their first
    ``:``, so the password may contain either character.

    >>> extract_credentials('http://localhost:5984/_config/')
    ('http://localhost:5984/_config/', None)
    >>> extract_credentials('http://joe:secret@localhost:5984/_config/')
    ('http://localhost:5984/_config/', ('joe', 'secret'))
    >>> extract_credentials('http://joe%40example.com:secret@localhost:5984/_config/')
    ('http://localhost:5984/_config/', ('joe@example.com', 'secret'))
    >>> extract_credentials('http://joe:se:cr@et@localhost:5984/')
    ('http://localhost:5984/', ('joe', 'se:cr@et'))
    """
    parts = urlsplit(url)
    netloc = parts[1]
    if '@' in netloc:
        creds, netloc = netloc.rsplit('@', 1)
        credentials = tuple(urllib.unquote(i) for i in creds.split(':', 1))
        parts = list(parts)
        parts[1] = netloc
    else:
        credentials = None
    return urlunsplit(parts), credentials
496
497
def basic_auth(credentials):
    """Return an HTTP basic ``Authorization`` header value.

    :param credentials: a ``(username, password)`` tuple, or a false value
                        when no authentication should be sent.
    :return: ``'Basic <base64 of "username:password">'``, or `None` if
             ``credentials`` is false.
    """
    if credentials:
        userpass = '%s:%s' % credentials
        # Encode text credentials as UTF-8 explicitly; b64encode would
        # otherwise implicitly ASCII-encode and fail on non-ASCII text.
        if isinstance(userpass, unicode):
            userpass = userpass.encode('utf-8')
        return 'Basic %s' % b64encode(userpass)
501
502
def quote(string, safe=''):
    """Percent-encode ``string`` for use in a URL.

    Text (``unicode``) is encoded to UTF-8 first.  With the default empty
    ``safe``, every reserved character, including ``/``, is escaped.
    """
    encoded = string.encode('utf-8') if isinstance(string, unicode) else string
    return urllib.quote(encoded, safe)
507
508
def urlencode(data):
    """Encode ``data`` as an ``application/x-www-form-urlencoded`` string.

    ``data`` is a dict or an iterable of ``(name, value)`` pairs; text
    (``unicode``) values are encoded to UTF-8 before quoting.
    """
    pairs = data.items() if isinstance(data, dict) else data
    encoded = [(key, val.encode('utf-8') if isinstance(val, unicode) else val)
               for key, val in pairs]
    return urllib.urlencode(encoded)
518
519
def urljoin(base, *path, **query):
    """Assemble a uri based on a base, any number of path segments, and query
    string parameters.

    Query values that are `None` are omitted; list/tuple values become
    repeated parameters (with `None` items skipped); booleans, including
    those inside lists, are sent as ``true``/``false``.

    >>> urljoin('http://example.org', '_all_dbs')
    'http://example.org/_all_dbs'

    A trailing slash on the uri base is handled gracefully:

    >>> urljoin('http://example.org/', '_all_dbs')
    'http://example.org/_all_dbs'

    And multiple positional arguments become path parts:

    >>> urljoin('http://example.org/', 'foo', 'bar')
    'http://example.org/foo/bar'

    All slashes within a path part are escaped:

    >>> urljoin('http://example.org/', 'foo/bar')
    'http://example.org/foo%2Fbar'
    >>> urljoin('http://example.org/', 'foo', '/bar/')
    'http://example.org/foo/%2Fbar%2F'

    >>> urljoin('http://example.org/', None) #doctest:+IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError: argument 2 to map() must support iteration
    """
    if base and base.endswith('/'):
        base = base[:-1]
    retval = [base]

    # build the path
    path = '/'.join([''] + [quote(s) for s in path])
    if path:
        retval.append(path)

    def _param_value(value):
        # Booleans are serialized as lowercase 'true'/'false' both as scalars
        # and inside sequences, rather than Python's 'True'/'False'.
        if value is True:
            return 'true'
        if value is False:
            return 'false'
        return value

    # build the query string
    params = []
    for name, value in query.items():
        if isinstance(value, (list, tuple)):
            params.extend([(name, _param_value(i))
                           for i in value if i is not None])
        elif value is not None:
            params.append((name, _param_value(value)))
    if params:
        retval.extend(['?', urlencode(params)])

    return ''.join(retval)
573
574