1#!/usr/bin/env python
2"""
3Binary memcached test client.
4
5Copyright (c) 2007  Dustin Sallings <dustin@spy.net>
6"""
7
8import sys
9import time
10import hmac
11import socket
12import random
13import struct
14import exceptions
15
16from memcacheConstants import REQ_MAGIC_BYTE, RES_MAGIC_BYTE
17from memcacheConstants import REQ_PKT_FMT, RES_PKT_FMT, MIN_RECV_PACKET
18from memcacheConstants import SET_PKT_FMT, DEL_PKT_FMT, INCRDECR_RES_FMT
19import memcacheConstants
20
class MemcachedError(Exception):
    """Error raised when a memcached command fails.

    Attributes:
        status: the status code reported for the failed command.
        msg: the payload that accompanied the failure (may be empty).
    """

    def __init__(self, status, msg):
        """Build the error text from ``status`` and the optional ``msg``."""
        # repr() is the exact equivalent of the deprecated backtick operator.
        supermsg = 'Memcached error #' + repr(status)
        if msg:
            supermsg += ":  " + msg
        Exception.__init__(self, supermsg)

        self.status = status
        self.msg = msg

    def __repr__(self):
        """Return a compact debugging form showing the status and message."""
        return "<MemcachedError #%d ``%s''>" % (self.status, self.msg)
34
35class MemcachedClient(object):
36    """Simple memcached client."""
37
38    vbucketId = 0
39
40    def __init__(self, host='127.0.0.1', port=11211):
41        self.s=socket.socket(socket.AF_INET, socket.SOCK_STREAM)
42        self.s.connect_ex((host, port))
43        self.r=random.Random()
44
45    def close(self):
46        self.s.close()
47
48    def __del__(self):
49        self.close()
50
51    def _sendCmd(self, cmd, key, val, opaque, extraHeader='', cas=0):
52        dtype=0
53        msg=struct.pack(REQ_PKT_FMT, REQ_MAGIC_BYTE,
54            cmd, len(key), len(extraHeader), dtype, self.vbucketId,
55                len(key) + len(extraHeader) + len(val), opaque, cas)
56        self.s.send(msg + extraHeader + key + val)
57
58    def _handleKeyedResponse(self, myopaque):
59        response = ""
60        while len(response) < MIN_RECV_PACKET:
61            response += self.s.recv(MIN_RECV_PACKET - len(response))
62        assert len(response) == MIN_RECV_PACKET
63        magic, cmd, keylen, extralen, dtype, errcode, remaining, opaque, cas=\
64            struct.unpack(RES_PKT_FMT, response)
65
66        rv = ""
67        while remaining > 0:
68            data = self.s.recv(remaining)
69            rv += data
70            remaining -= len(data)
71
72        assert (magic in (RES_MAGIC_BYTE, REQ_MAGIC_BYTE)), "Got magic: %d" % magic
73        assert myopaque is None or opaque == myopaque, \
74            "expected opaque %x, got %x" % (myopaque, opaque)
75        if errcode != 0:
76            raise MemcachedError(errcode,  rv)
77        return cmd, opaque, cas, keylen, extralen, rv
78
79    def _handleSingleResponse(self, myopaque):
80        cmd, opaque, cas, keylen, extralen, data = self._handleKeyedResponse(myopaque)
81        return opaque, cas, data
82
83    def _doCmd(self, cmd, key, val, extraHeader='', cas=0):
84        """Send a command and await its response."""
85        opaque=self.r.randint(0, 2**32)
86        self._sendCmd(cmd, key, val, opaque, extraHeader, cas)
87        return self._handleSingleResponse(opaque)
88
89    def _mutate(self, cmd, key, exp, flags, cas, val):
90        return self._doCmd(cmd, key, val, struct.pack(SET_PKT_FMT, flags, exp),
91            cas)
92
93    def _cat(self, cmd, key, cas, val):
94        return self._doCmd(cmd, key, val, '', cas)
95
96    def append(self, key, value, cas=0):
97        return self._cat(memcacheConstants.CMD_APPEND, key, cas, value)
98
99    def prepend(self, key, value, cas=0):
100        return self._cat(memcacheConstants.CMD_PREPEND, key, cas, value)
101
102    def __incrdecr(self, cmd, key, amt, init, exp):
103        something, cas, val=self._doCmd(cmd, key, '',
104            struct.pack(memcacheConstants.INCRDECR_PKT_FMT, amt, init, exp))
105        return struct.unpack(INCRDECR_RES_FMT, val)[0], cas
106
107    def incr(self, key, amt=1, init=0, exp=0):
108        """Increment or create the named counter."""
109        return self.__incrdecr(memcacheConstants.CMD_INCR, key, amt, init, exp)
110
111    def decr(self, key, amt=1, init=0, exp=0):
112        """Decrement or create the named counter."""
113        return self.__incrdecr(memcacheConstants.CMD_DECR, key, amt, init, exp)
114
115    def set(self, key, exp, flags, val):
116        """Set a value in the memcached server."""
117        return self._mutate(memcacheConstants.CMD_SET, key, exp, flags, 0, val)
118
119    def add(self, key, exp, flags, val):
120        """Add a value in the memcached server iff it doesn't already exist."""
121        return self._mutate(memcacheConstants.CMD_ADD, key, exp, flags, 0, val)
122
123    def replace(self, key, exp, flags, val):
124        """Replace a value in the memcached server iff it already exists."""
125        return self._mutate(memcacheConstants.CMD_REPLACE, key, exp, flags, 0,
126            val)
127
128    def __parseGet(self, data):
129        flags=struct.unpack(memcacheConstants.GET_RES_FMT, data[-1][:4])[0]
130        return flags, data[1], data[-1][4:]
131
132    def get(self, key):
133        """Get the value for a given key within the memcached server."""
134        parts=self._doCmd(memcacheConstants.CMD_GET, key, '')
135        return self.__parseGet(parts)
136
137    def cas(self, key, exp, flags, oldVal, val):
138        """CAS in a new value for the given key and comparison value."""
139        self._mutate(memcacheConstants.CMD_SET, key, exp, flags,
140            oldVal, val)
141
142    def version(self):
143        """Get the value for a given key within the memcached server."""
144        return self._doCmd(memcacheConstants.CMD_VERSION, '', '')
145
146    def sasl_mechanisms(self):
147        """Get the supported SASL methods."""
148        return set(self._doCmd(memcacheConstants.CMD_SASL_LIST_MECHS,
149                               '', '')[2].split(' '))
150
151    def sasl_auth_start(self, mech, data):
152        """Start a sasl auth session."""
153        return self._doCmd(memcacheConstants.CMD_SASL_AUTH, mech, data)
154
155    def sasl_auth_plain(self, user, password, foruser=''):
156        """Perform plain auth."""
157        return self.sasl_auth_start('PLAIN', '\0'.join([foruser, user, password]))
158
159    def sasl_auth_cram_md5(self, user, password):
160        """Start a plan auth session."""
161        try:
162            self.sasl_auth_start('CRAM-MD5', '')
163        except MemcachedError, e:
164            if e.status != memcacheConstants.ERR_AUTH_CONTINUE:
165                raise
166            challenge = e.msg
167
168        dig = hmac.HMAC(password, challenge).hexdigest()
169        return self._doCmd(memcacheConstants.CMD_SASL_STEP, 'CRAM-MD5',
170                           user + ' ' + dig)
171
172    def stop_persistence(self):
173        return self._doCmd(memcacheConstants.CMD_STOP_PERSISTENCE, '', '')
174
175    def start_persistence(self):
176        return self._doCmd(memcacheConstants.CMD_START_PERSISTENCE, '', '')
177
178    def set_flush_param(self, key, val):
179        print "setting flush param:", key, val
180        return self._doCmd(memcacheConstants.CMD_SET_FLUSH_PARAM, key, val)
181
182    def stop_replication(self):
183        return self._doCmd(memcacheConstants.CMD_STOP_REPLICATION, '', '')
184
185    def start_replication(self):
186        return self._doCmd(memcacheConstants.CMD_START_REPLICATION, '', '')
187
188    def set_tap_param(self, key, val):
189        print "setting tap param:", key, val
190        return self._doCmd(memcacheConstants.CMD_SET_TAP_PARAM, key, val)
191
192    def set_vbucket_state(self, vbucket, state):
193        return self._doCmd(memcacheConstants.CMD_SET_VBUCKET_STATE,
194                           str(vbucket), state)
195
196    def delete_vbucket(self, vbucket):
197        return self._doCmd(memcacheConstants.CMD_DELETE_VBUCKET, str(vbucket), '')
198
199    def evict_key(self, key):
200        return self._doCmd(memcacheConstants.CMD_EVICT_KEY, key, '')
201
202    def getMulti(self, keys):
203        """Get values for any available keys in the given iterable.
204
205        Returns a dict of matched keys to their values."""
206        opaqued=dict(enumerate(keys))
207        terminal=len(opaqued)+10
208        # Send all of the keys in quiet
209        for k,v in opaqued.iteritems():
210            self._sendCmd(memcacheConstants.CMD_GETQ, v, '', k)
211
212        self._sendCmd(memcacheConstants.CMD_NOOP, '', '', terminal)
213
214        # Handle the response
215        rv={}
216        done=False
217        while not done:
218            opaque, cas, data=self._handleSingleResponse(None)
219            if opaque != terminal:
220                rv[opaqued[opaque]]=self.__parseGet((opaque, cas, data))
221            else:
222                done=True
223
224        return rv
225
226    def stats(self, sub=''):
227        """Get stats."""
228        opaque=self.r.randint(0, 2**32)
229        self._sendCmd(memcacheConstants.CMD_STAT, sub, '', opaque)
230        done = False
231        rv = {}
232        while not done:
233            cmd, opaque, cas, klen, extralen, data = self._handleKeyedResponse(None)
234            if klen:
235                rv[data[0:klen]] = data[klen:]
236            else:
237                done = True
238        return rv
239
240    def noop(self):
241        """Send a noop command."""
242        return self._doCmd(memcacheConstants.CMD_NOOP, '', '')
243
244    def delete(self, key, cas=0):
245        """Delete the value for a given key within the memcached server."""
246        return self._doCmd(memcacheConstants.CMD_DELETE, key, '', '', cas)
247
248    def flush(self, timebomb=0):
249        """Flush all storage in a memcached instance."""
250        return self._doCmd(memcacheConstants.CMD_FLUSH, '', '',
251            struct.pack(memcacheConstants.FLUSH_PKT_FMT, timebomb))
252