1#!/usr/bin/env python
2"""
3A memcached test server.
4
5Copyright (c) 2007  Dustin Sallings <dustin@spy.net>
6"""
7
8import asyncore
9import random
10import string
11import socket
12import struct
13import time
14import hmac
15import heapq
16
17import memcacheConstants
18
19from memcacheConstants import MIN_RECV_PACKET, REQ_PKT_FMT, RES_PKT_FMT
20from memcacheConstants import INCRDECR_RES_FMT
21from memcacheConstants import REQ_MAGIC_BYTE, RES_MAGIC_BYTE, EXTRA_HDR_FMTS
22
# Version string reported to clients by DictBackend.handle_version.
VERSION="1.0"
24
25class BaseBackend(object):
26    """Higher-level backend (processes commands and stuff)."""
27
28    # Command IDs to method names.  This is used to build a dispatch dict on
29    # the fly.
30    CMDS={
31        memcacheConstants.CMD_GET: 'handle_get',
32        memcacheConstants.CMD_GETQ: 'handle_getq',
33        memcacheConstants.CMD_SET: 'handle_set',
34        memcacheConstants.CMD_ADD: 'handle_add',
35        memcacheConstants.CMD_REPLACE: 'handle_replace',
36        memcacheConstants.CMD_DELETE: 'handle_delete',
37        memcacheConstants.CMD_INCR: 'handle_incr',
38        memcacheConstants.CMD_DECR: 'handle_decr',
39        memcacheConstants.CMD_QUIT: 'handle_quit',
40        memcacheConstants.CMD_FLUSH: 'handle_flush',
41        memcacheConstants.CMD_NOOP: 'handle_noop',
42        memcacheConstants.CMD_VERSION: 'handle_version',
43        memcacheConstants.CMD_APPEND: 'handle_append',
44        memcacheConstants.CMD_PREPEND: 'handle_prepend',
45        memcacheConstants.CMD_SASL_LIST_MECHS: 'handle_sasl_mechs',
46        memcacheConstants.CMD_SASL_AUTH: 'handle_sasl_auth',
47        memcacheConstants.CMD_SASL_STEP: 'handle_sasl_step',
48        }
49
50    def __init__(self):
51        self.handlers={}
52        self.sched=[]
53
54        for id, method in self.CMDS.iteritems():
55            self.handlers[id]=getattr(self, method, self.handle_unknown)
56
57    def _splitKeys(self, fmt, keylen, data):
58        """Split the given data into the headers as specified in the given
59        format, the key, and the data.
60
61        Return (hdrTuple, key, data)"""
62        hdrSize=struct.calcsize(fmt)
63        assert hdrSize <= len(data), "Data too short for " + fmt + ': ' + `data`
64        hdr=struct.unpack(fmt, data[:hdrSize])
65        assert len(data) >= hdrSize + keylen
66        key=data[hdrSize:keylen+hdrSize]
67        assert len(key) == keylen, "len(%s) == %d, expected %d" \
68            % (key, len(key), keylen)
69        val=data[keylen+hdrSize:]
70        return hdr, key, val
71
72    def _error(self, which, msg):
73        return which, 0, msg
74
75    def processCommand(self, cmd, keylen, vb, cas, data):
76        """Entry point for command processing.  Lower level protocol
77        implementations deliver values here."""
78
79        now=time.time()
80        while self.sched and self.sched[0][0] <= now:
81            print "Running delayed job."
82            heapq.heappop(self.sched)[1]()
83
84        hdrs, key, val=self._splitKeys(EXTRA_HDR_FMTS.get(cmd, ''),
85            keylen, data)
86
87        return self.handlers.get(cmd, self.handle_unknown)(cmd, hdrs, key,
88            cas, val)
89
90    def handle_noop(self, cmd, hdrs, key, cas, data):
91        """Handle a noop"""
92        print "Noop"
93        return 0, 0, ''
94
95    def handle_unknown(self, cmd, hdrs, key, cas, data):
96        """invoked for any unknown command."""
97        return self._error(memcacheConstants.ERR_UNKNOWN_CMD,
98            "The command %d is unknown" % cmd)
99
100class DictBackend(BaseBackend):
101    """Sample backend implementation with a non-expiring dict."""
102
103    def __init__(self):
104        super(DictBackend, self).__init__()
105        self.storage={}
106        self.held_keys={}
107        self.challenge = ''.join(random.sample(string.ascii_letters
108                                               + string.digits, 32))
109
110    def __lookup(self, key):
111        rv=self.storage.get(key, None)
112        if rv:
113            now=time.time()
114            if now >= rv[1]:
115                print key, "expired"
116                del self.storage[key]
117                rv=None
118        else:
119            print "Miss looking up", key
120        return rv
121
122    def handle_get(self, cmd, hdrs, key, cas, data):
123        val=self.__lookup(key)
124        if val:
125            rv = 0, id(val), struct.pack(
126                memcacheConstants.GET_RES_FMT, val[0]) + str(val[2])
127        else:
128            rv=self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
129        return rv
130
131    def handle_set(self, cmd, hdrs, key, cas, data):
132        print "Handling a set with", hdrs
133        val=self.__lookup(key)
134        exp, flags=hdrs
135        def f(val):
136            return self.__handle_unconditional_set(cmd, hdrs, key, data)
137        return self._withCAS(key, cas, f)
138
139    def handle_getq(self, cmd, hdrs, key, cas, data):
140        rv=self.handle_get(cmd, hdrs, key, cas, data)
141        if rv[0] == memcacheConstants.ERR_NOT_FOUND:
142            print "Swallowing miss"
143            rv = None
144        return rv
145
146    def __handle_unconditional_set(self, cmd, hdrs, key, data):
147        exp=hdrs[1]
148        # If it's going to expire soon, tell it to wait a while.
149        if exp == 0:
150            exp=float(2 ** 31)
151        self.storage[key]=(hdrs[0], time.time() + exp, data)
152        print "Stored", self.storage[key], "in", key
153        if key in self.held_keys:
154            del self.held_keys[key]
155        return 0, id(self.storage[key]), ''
156
157    def __mutation(self, cmd, hdrs, key, data, multiplier):
158        amount, initial, expiration=hdrs
159        rv=self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
160        val=self.storage.get(key, None)
161        print "Mutating %s, hdrs=%s, val=%s %s" % (key, `hdrs`, `val`,
162            multiplier)
163        if val:
164            val = (val[0], val[1], max(0, long(val[2]) + (multiplier * amount)))
165            self.storage[key]=val
166            rv=0, id(val), str(val[2])
167        else:
168            if expiration != memcacheConstants.INCRDECR_SPECIAL:
169                self.storage[key]=(0, time.time() + expiration, initial)
170                rv=0, id(self.storage[key]), str(initial)
171        if rv[0] == 0:
172            rv = rv[0], rv[1], struct.pack(
173                memcacheConstants.INCRDECR_RES_FMT, long(rv[2]))
174        print "Returning", rv
175        return rv
176
177    def handle_incr(self, cmd, hdrs, key, cas, data):
178        return self.__mutation(cmd, hdrs, key, data, 1)
179
180    def handle_decr(self, cmd, hdrs, key, cas, data):
181        return self.__mutation(cmd, hdrs, key, data, -1)
182
183    def __has_hold(self, key):
184        rv=False
185        now=time.time()
186        print "Looking for hold of", key, "in", self.held_keys, "as of", now
187        if key in self.held_keys:
188            if time.time() > self.held_keys[key]:
189                del self.held_keys[key]
190            else:
191                rv=True
192        return rv
193
194    def handle_add(self, cmd, hdrs, key, cas, data):
195        rv=self._error(memcacheConstants.ERR_EXISTS, 'Data exists for key')
196        if key not in self.storage and not self.__has_hold(key):
197            rv=self.__handle_unconditional_set(cmd, hdrs, key, data)
198        return rv
199
200    def handle_replace(self, cmd, hdrs, key, cas, data):
201        rv=self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
202        if key in self.storage and not self.__has_hold(key):
203            rv=self.__handle_unconditional_set(cmd, hdrs, key, data)
204        return rv
205
206    def handle_flush(self, cmd, hdrs, key, cas, data):
207        timebomb_delay=hdrs[0]
208        def f():
209            self.storage.clear()
210            self.held_keys.clear()
211            print "Flushed"
212        if timebomb_delay:
213            heapq.heappush(self.sched, (time.time() + timebomb_delay, f))
214        else:
215            f()
216        return 0, 0, ''
217
218    def handle_delete(self, cmd, hdrs, key, cas, data):
219        def f(val):
220            rv=self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
221            if val:
222                del self.storage[key]
223                rv = 0, 0, ''
224            print "Deleted", key, hdrs[0]
225            if hdrs[0] > 0:
226                self.held_keys[key] = time.time() + hdrs[0]
227            return rv
228        return self._withCAS(key, cas, f)
229
230    def handle_version(self, cmd, hdrs, key, cas, data):
231        return 0, 0, "Python test memcached server %s" % VERSION
232
233    def _withCAS(self, key, cas, f):
234        val=self.storage.get(key, None)
235        if cas == 0 or (val and cas == id(val)):
236            rv=f(val)
237        elif val:
238            rv = self._error(memcacheConstants.ERR_EXISTS, 'Exists')
239        else:
240            rv = self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
241        return rv
242
243    def handle_prepend(self, cmd, hdrs, key, cas, data):
244        def f(val):
245            self.storage[key]=(val[0], val[1], data + val[2])
246            return 0, id(self.storage[key]), ''
247        return self._withCAS(key, cas, f)
248
249    def handle_append(self, cmd, hdrs, key, cas, data):
250        def f(val):
251            self.storage[key]=(val[0], val[1], val[2] + data)
252            return 0, id(self.storage[key]), ''
253        return self._withCAS(key, cas, f)
254
255    def handle_sasl_mechs(self, cmd, hdrs, key, cas, data):
256        return 0, 0, 'PLAIN CRAM-MD5'
257
258    def handle_sasl_step(self, cmd, hdrs, key, cas, data):
259        assert key == 'CRAM-MD5'
260
261        u, resp = data.split(' ', 1)
262        expected = hmac.HMAC('testpass', self.challenge).hexdigest()
263
264        if u == 'testuser' and resp == expected:
265            print "Successful CRAM-MD5 auth."
266            return 0, 0, 'OK'
267        else:
268            print "Errored a CRAM-MD5 auth."
269            return self._error(memcacheConstants.ERR_AUTH, 'Auth error.')
270
271    def _handle_sasl_auth_plain(self, data):
272        foruser, user, passwd = data.split("\0")
273        if user == 'testuser' and passwd == 'testpass':
274            print "Successful plain auth"
275            return 0, 0, "OK"
276        else:
277            print "Bad username/password:  %s/%s" % (user, passwd)
278            return self._error(memcacheConstants.ERR_AUTH, 'Auth error.')
279
280    def _handle_sasl_auth_cram_md5(self, data):
281        assert data == ''
282        print "Issuing %s as a CRAM-MD5 challenge." % self.challenge
283        return memcacheConstants.ERR_AUTH_CONTINUE, 0, self.challenge
284
285    def handle_sasl_auth(self, cmd, hdrs, key, cas, data):
286        mech = key
287
288        if mech == 'PLAIN':
289            return self._handle_sasl_auth_plain(data)
290        elif mech == 'CRAM-MD5':
291            return self._handle_sasl_auth_cram_md5(data)
292        else:
293            print "Unhandled auth type:  %s" % mech
294            return self._error(memcacheConstants.ERR_AUTH, 'Auth error.')
295
296class MemcachedBinaryChannel(asyncore.dispatcher):
297    """A channel implementing the binary protocol for memcached."""
298
299    # Receive buffer size
300    BUFFER_SIZE = 4096
301
302    def __init__(self, channel, backend, wbuf=""):
303        asyncore.dispatcher.__init__(self, channel)
304        self.log_info("New bin connection from %s" % str(self.addr))
305        self.backend=backend
306        self.wbuf=wbuf
307        self.rbuf=""
308
309    def __hasEnoughBytes(self):
310        rv=False
311        if len(self.rbuf) >= MIN_RECV_PACKET:
312            magic, cmd, keylen, extralen, datatype, vb, remaining, opaque, cas=\
313                struct.unpack(REQ_PKT_FMT, self.rbuf[:MIN_RECV_PACKET])
314            rv = len(self.rbuf) - MIN_RECV_PACKET >= remaining
315        return rv
316
317    def processCommand(self, cmd, keylen, vb, cas, data):
318        return self.backend.processCommand(cmd, keylen, vb, cas, data)
319
320    def handle_read(self):
321        self.rbuf += self.recv(self.BUFFER_SIZE)
322        while self.__hasEnoughBytes():
323            magic, cmd, keylen, extralen, datatype, vb, remaining, opaque, cas=\
324                struct.unpack(REQ_PKT_FMT, self.rbuf[:MIN_RECV_PACKET])
325            assert magic == REQ_MAGIC_BYTE
326            assert keylen <= remaining, "Keylen is too big: %d > %d" \
327                % (keylen, remaining)
328            assert extralen == memcacheConstants.EXTRA_HDR_SIZES.get(cmd, 0), \
329                "Extralen is too large for cmd 0x%x: %d" % (cmd, extralen)
330            # Grab the data section of this request
331            data=self.rbuf[MIN_RECV_PACKET:MIN_RECV_PACKET+remaining]
332            assert len(data) == remaining
333            # Remove this request from the read buffer
334            self.rbuf=self.rbuf[MIN_RECV_PACKET+remaining:]
335            # Process the command
336            cmdVal = self.processCommand(cmd, keylen, vb, extralen, cas, data)
337            # Queue the response to the client if applicable.
338            if cmdVal:
339                try:
340                    status, cas, response = cmdVal
341                except ValueError:
342                    print "Got", cmdVal
343                    raise
344                dtype=0
345                extralen=memcacheConstants.EXTRA_HDR_SIZES.get(cmd, 0)
346                self.wbuf += struct.pack(RES_PKT_FMT,
347                    RES_MAGIC_BYTE, cmd, keylen,
348                    extralen, dtype, status,
349                    len(response), opaque, cas) + response
350
351    def writable(self):
352        return self.wbuf
353
354    def handle_write(self):
355        sent = self.send(self.wbuf)
356        self.wbuf = self.wbuf[sent:]
357
358    def handle_close(self):
359        self.log_info("Disconnected from %s" % str(self.addr))
360        self.close()
361
class MemcachedServer(asyncore.dispatcher):
    """A memcached server.

    Listens on a TCP port and creates one handler per accepted
    connection, all sharing the same backend.
    """
    def __init__(self, backend, handler, port=11211):
        """Bind to port on all interfaces and start listening.

        backend is passed to every handler; handler is called as
        ``handler(channel, backend)`` for each accepted connection."""
        asyncore.dispatcher.__init__(self)

        self.handler=handler
        self.backend=backend

        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(("", port))
        self.listen(5)
        self.log_info("Listening on %d" % port)

    def handle_accept(self):
        """Accept a pending connection and hand it to a new handler."""
        accepted = self.accept()
        # accept() returns None when the connection is already gone or the
        # call would block; there is nothing to hand over in that case.
        if accepted is None:
            return
        channel, addr = accepted
        self.handler(channel, self.backend)
379
if __name__ == '__main__':
    # Usage: script [port].  Serves a DictBackend over the binary protocol
    # until interrupted.
    import sys
    port = 11211
    # Only read the port when an argument was actually given; comparing the
    # argv list itself against 1 is always true in Python 2.
    if len(sys.argv) > 1:
        port = int(sys.argv[1])
    server = MemcachedServer(DictBackend(), MemcachedBinaryChannel, port=port)
    asyncore.loop()
387