#!/usr/bin/env python
"""
A memcached test server.

Copyright (c) 2007 Dustin Sallings <dustin@spy.net>
"""

import asyncore
import random
import string
import socket
import struct
import time
import hmac
import heapq

import memcacheConstants

from memcacheConstants import MIN_RECV_PACKET, REQ_PKT_FMT, RES_PKT_FMT
from memcacheConstants import INCRDECR_RES_FMT
from memcacheConstants import REQ_MAGIC_BYTE, RES_MAGIC_BYTE, EXTRA_HDR_FMTS

VERSION = "1.0"


class BaseBackend(object):
    """Higher-level backend (processes commands and stuff).

    Handlers receive (cmd, hdrs, key, cas, data) and return either a
    (status, cas, body) tuple or None when no response should be sent.
    """

    # Command IDs to method names.  This is used to build a dispatch dict on
    # the fly.
    CMDS = {
        memcacheConstants.CMD_GET: 'handle_get',
        memcacheConstants.CMD_GETQ: 'handle_getq',
        memcacheConstants.CMD_SET: 'handle_set',
        memcacheConstants.CMD_ADD: 'handle_add',
        memcacheConstants.CMD_REPLACE: 'handle_replace',
        memcacheConstants.CMD_DELETE: 'handle_delete',
        memcacheConstants.CMD_INCR: 'handle_incr',
        memcacheConstants.CMD_DECR: 'handle_decr',
        memcacheConstants.CMD_QUIT: 'handle_quit',
        memcacheConstants.CMD_FLUSH: 'handle_flush',
        memcacheConstants.CMD_NOOP: 'handle_noop',
        memcacheConstants.CMD_VERSION: 'handle_version',
        memcacheConstants.CMD_APPEND: 'handle_append',
        memcacheConstants.CMD_PREPEND: 'handle_prepend',
        memcacheConstants.CMD_SASL_LIST_MECHS: 'handle_sasl_mechs',
        memcacheConstants.CMD_SASL_AUTH: 'handle_sasl_auth',
        memcacheConstants.CMD_SASL_STEP: 'handle_sasl_step',
    }

    def __init__(self):
        """Build the command dispatch table and the delayed-job heap."""
        self.handlers = {}
        # Heap of (run_at_timestamp, callable) pairs; see processCommand.
        self.sched = []

        # Commands whose handler method is not implemented by this class fall
        # back to handle_unknown.
        for cmd_id, method in self.CMDS.items():
            self.handlers[cmd_id] = getattr(self, method, self.handle_unknown)

    def _splitKeys(self, fmt, keylen, data):
        """Split the given data into the headers as specified in the given
        format, the key, and the data.

        Return (hdrTuple, key, data)"""
        hdrSize = struct.calcsize(fmt)
        assert hdrSize <= len(data), \
            "Data too short for " + fmt + ': ' + repr(data)
        hdr = struct.unpack(fmt, data[:hdrSize])
        assert len(data) >= hdrSize + keylen
        key = data[hdrSize:keylen + hdrSize]
        assert len(key) == keylen, "len(%s) == %d, expected %d" \
            % (key, len(key), keylen)
        val = data[keylen + hdrSize:]
        return hdr, key, val

    def _error(self, which, msg):
        """Build an error response tuple: (status, cas=0, message)."""
        return which, 0, msg

    def processCommand(self, cmd, keylen, vb, cas, data):
        """Entry point for command processing.  Lower level protocol
        implementations deliver values here.

        Runs any delayed jobs that have come due, splits ``data`` into
        extras/key/value and dispatches to the handler registered for
        ``cmd`` (handle_unknown when there is none).
        """

        now = time.time()
        while self.sched and self.sched[0][0] <= now:
            print("Running delayed job.")
            heapq.heappop(self.sched)[1]()

        hdrs, key, val = self._splitKeys(EXTRA_HDR_FMTS.get(cmd, ''),
                                         keylen, data)

        handler = self.handlers.get(cmd, self.handle_unknown)
        return handler(cmd, hdrs, key, cas, val)

    def handle_noop(self, cmd, hdrs, key, cas, data):
        """Handle a noop"""
        print("Noop")
        return 0, 0, ''

    def handle_unknown(self, cmd, hdrs, key, cas, data):
        """invoked for any unknown command."""
        return self._error(memcacheConstants.ERR_UNKNOWN_CMD,
                           "The command %d is unknown" % cmd)


class DictBackend(BaseBackend):
    """Sample backend implementation backed by an in-memory dict.

    ``storage`` maps key -> (hdrs[0] of the storing command, absolute expiry
    timestamp, value).  ``held_keys`` maps key -> timestamp until which the
    key may not be re-added/replaced after a delete with a hold time.
    The id() of the stored tuple doubles as the item's CAS value.
    """

    def __init__(self):
        """Create empty storage and a random CRAM-MD5 challenge string."""
        super(DictBackend, self).__init__()
        self.storage = {}
        self.held_keys = {}
        self.challenge = ''.join(random.sample(string.ascii_letters
                                               + string.digits, 32))

    def __lookup(self, key):
        """Return the stored tuple for ``key``, or None on a miss.

        An entry whose expiry time has passed is removed and reported as a
        miss.
        """
        rv = self.storage.get(key, None)
        if rv:
            now = time.time()
            if now >= rv[1]:
                print("%s expired" % (key,))
                del self.storage[key]
                rv = None
        else:
            print("Miss looking up %s" % (key,))
        return rv

    def handle_get(self, cmd, hdrs, key, cas, data):
        """Return the value for ``key`` prefixed with its packed GET extras,
        or a not-found error."""
        val = self.__lookup(key)
        if val:
            rv = 0, id(val), struct.pack(
                memcacheConstants.GET_RES_FMT, val[0]) + str(val[2])
        else:
            rv = self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
        return rv

    def handle_set(self, cmd, hdrs, key, cas, data):
        """Store ``data`` under ``key``, honouring the CAS check."""
        print("Handling a set with %s" % (hdrs,))
        # Called only for its side effect: an expired entry is dropped before
        # the CAS comparison in _withCAS looks at the storage.
        self.__lookup(key)

        def f(val):
            return self.__handle_unconditional_set(cmd, hdrs, key, data)
        return self._withCAS(key, cas, f)

    def handle_getq(self, cmd, hdrs, key, cas, data):
        """Quiet get: like handle_get, but a miss produces no response."""
        rv = self.handle_get(cmd, hdrs, key, cas, data)
        if rv[0] == memcacheConstants.ERR_NOT_FOUND:
            print("Swallowing miss")
            rv = None
        return rv

    def __handle_unconditional_set(self, cmd, hdrs, key, data):
        """Store ``data`` under ``key`` without any CAS check.

        hdrs[1] is the relative expiration in seconds; 0 is stored as a very
        distant expiry.  Any delete-hold on the key is cleared.
        """
        exp = hdrs[1]
        # An expiration of 0 is stored as 2**31 seconds from now, i.e. far
        # enough in the future that the entry effectively never expires.
        if exp == 0:
            exp = float(2 ** 31)
        self.storage[key] = (hdrs[0], time.time() + exp, data)
        print("Stored %s in %s" % (self.storage[key], key))
        if key in self.held_keys:
            del self.held_keys[key]
        return 0, id(self.storage[key]), ''

    def __mutation(self, cmd, hdrs, key, data, multiplier):
        """Shared implementation of incr (multiplier=1) and decr (-1).

        hdrs is (amount, initial, expiration).  An existing value is adjusted
        and clamped at 0; a missing key is created with ``initial`` unless
        expiration equals INCRDECR_SPECIAL, in which case not-found is
        returned.  Successful responses carry the packed new value.
        """
        amount, initial, expiration = hdrs
        rv = self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
        val = self.storage.get(key, None)
        print("Mutating %s, hdrs=%s, val=%s %s" % (key, repr(hdrs), repr(val),
                                                    multiplier))
        if val:
            val = (val[0], val[1],
                   max(0, long(val[2]) + (multiplier * amount)))
            self.storage[key] = val
            rv = 0, id(val), str(val[2])
        else:
            if expiration != memcacheConstants.INCRDECR_SPECIAL:
                self.storage[key] = (0, time.time() + expiration, initial)
                rv = 0, id(self.storage[key]), str(initial)
        if rv[0] == 0:
            rv = rv[0], rv[1], struct.pack(
                memcacheConstants.INCRDECR_RES_FMT, long(rv[2]))
        print("Returning %s" % (rv,))
        return rv

    def handle_incr(self, cmd, hdrs, key, cas, data):
        """Increment the counter stored at ``key``."""
        return self.__mutation(cmd, hdrs, key, data, 1)

    def handle_decr(self, cmd, hdrs, key, cas, data):
        """Decrement the counter stored at ``key`` (never below 0)."""
        return self.__mutation(cmd, hdrs, key, data, -1)

    def __has_hold(self, key):
        """Return True while ``key`` is under an unexpired delete-hold.

        An expired hold is removed as a side effect.
        """
        rv = False
        now = time.time()
        print("Looking for hold of %s in %s as of %s"
              % (key, self.held_keys, now))
        if key in self.held_keys:
            if now > self.held_keys[key]:
                del self.held_keys[key]
            else:
                rv = True
        return rv

    def handle_add(self, cmd, hdrs, key, cas, data):
        """Store ``data`` only if ``key`` is absent and not held."""
        rv = self._error(memcacheConstants.ERR_EXISTS, 'Data exists for key')
        if key not in self.storage and not self.__has_hold(key):
            rv = self.__handle_unconditional_set(cmd, hdrs, key, data)
        return rv

    def handle_replace(self, cmd, hdrs, key, cas, data):
        """Store ``data`` only if ``key`` is already present and not held."""
        rv = self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
        if key in self.storage and not self.__has_hold(key):
            rv = self.__handle_unconditional_set(cmd, hdrs, key, data)
        return rv

    def handle_flush(self, cmd, hdrs, key, cas, data):
        """Clear all storage and holds, immediately or after hdrs[0] seconds.

        A delayed flush is queued on ``self.sched`` and runs the next time
        processCommand is invoked after it comes due.
        """
        timebomb_delay = hdrs[0]

        def f():
            self.storage.clear()
            self.held_keys.clear()
            print("Flushed")

        if timebomb_delay:
            heapq.heappush(self.sched, (time.time() + timebomb_delay, f))
        else:
            f()
        return 0, 0, ''

    def handle_delete(self, cmd, hdrs, key, cas, data):
        """Delete ``key`` (subject to CAS) and optionally place a hold on it.

        A positive hdrs[0] registers a hold of that many seconds, during
        which add/replace on the key are refused.
        """
        def f(val):
            rv = self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
            if val:
                del self.storage[key]
                rv = 0, 0, ''
            print("Deleted %s %s" % (key, hdrs[0]))
            if hdrs[0] > 0:
                self.held_keys[key] = time.time() + hdrs[0]
            return rv
        return self._withCAS(key, cas, f)

    def handle_version(self, cmd, hdrs, key, cas, data):
        """Return the server version string."""
        return 0, 0, "Python test memcached server %s" % VERSION

    def _withCAS(self, key, cas, f):
        """Run ``f(current_value)`` if the CAS check passes.

        The check passes when ``cas`` is 0 (no CAS requested) or equals the
        id() of the stored tuple.  Otherwise returns ERR_EXISTS when the key
        is present and ERR_NOT_FOUND when it is not.  Note that ``f`` may be
        called with None when cas is 0 and the key is missing.
        """
        val = self.storage.get(key, None)
        if cas == 0 or (val and cas == id(val)):
            rv = f(val)
        elif val:
            rv = self._error(memcacheConstants.ERR_EXISTS, 'Exists')
        else:
            rv = self._error(memcacheConstants.ERR_NOT_FOUND, 'Not found')
        return rv

    def handle_prepend(self, cmd, hdrs, key, cas, data):
        """Prepend ``data`` to the existing value of ``key``."""
        def f(val):
            # _withCAS passes None for a missing key when no CAS was given.
            if not val:
                return self._error(memcacheConstants.ERR_NOT_FOUND,
                                   'Not found')
            # str(): incr/decr may have left a number in the value slot.
            self.storage[key] = (val[0], val[1], data + str(val[2]))
            return 0, id(self.storage[key]), ''
        return self._withCAS(key, cas, f)

    def handle_append(self, cmd, hdrs, key, cas, data):
        """Append ``data`` to the existing value of ``key``."""
        def f(val):
            # _withCAS passes None for a missing key when no CAS was given.
            if not val:
                return self._error(memcacheConstants.ERR_NOT_FOUND,
                                   'Not found')
            # str(): incr/decr may have left a number in the value slot.
            self.storage[key] = (val[0], val[1], str(val[2]) + data)
            return 0, id(self.storage[key]), ''
        return self._withCAS(key, cas, f)

    def handle_sasl_mechs(self, cmd, hdrs, key, cas, data):
        """List the supported SASL mechanisms."""
        return 0, 0, 'PLAIN CRAM-MD5'

    def handle_sasl_step(self, cmd, hdrs, key, cas, data):
        """Verify a CRAM-MD5 response of the form '<user> <hexdigest>'.

        The digest must be the HMAC of this backend's challenge keyed with
        the fixed test password.
        """
        assert key == 'CRAM-MD5'

        # partition() instead of split(): a payload without a space yields
        # an empty response (and an auth error) rather than a ValueError.
        u, _, resp = data.partition(' ')
        expected = hmac.HMAC('testpass', self.challenge).hexdigest()

        if u == 'testuser' and resp == expected:
            print("Successful CRAM-MD5 auth.")
            return 0, 0, 'OK'
        else:
            print("Errored a CRAM-MD5 auth.")
            return self._error(memcacheConstants.ERR_AUTH, 'Auth error.')

    def _handle_sasl_auth_plain(self, data):
        """Check a PLAIN credential blob: 'foruser\\0user\\0password'."""
        parts = data.split("\0")
        if len(parts) != 3:
            # Malformed client payload: reject instead of raising.
            print("Malformed plain auth request")
            return self._error(memcacheConstants.ERR_AUTH, 'Auth error.')
        foruser, user, passwd = parts
        if user == 'testuser' and passwd == 'testpass':
            print("Successful plain auth")
            return 0, 0, "OK"
        else:
            print("Bad username/password: %s/%s" % (user, passwd))
            return self._error(memcacheConstants.ERR_AUTH, 'Auth error.')

    def _handle_sasl_auth_cram_md5(self, data):
        """Start a CRAM-MD5 exchange by sending the challenge."""
        assert data == ''
        print("Issuing %s as a CRAM-MD5 challenge." % self.challenge)
        return memcacheConstants.ERR_AUTH_CONTINUE, 0, self.challenge

    def handle_sasl_auth(self, cmd, hdrs, key, cas, data):
        """Dispatch a SASL auth request on the mechanism named by ``key``."""
        mech = key

        if mech == 'PLAIN':
            return self._handle_sasl_auth_plain(data)
        elif mech == 'CRAM-MD5':
            return self._handle_sasl_auth_cram_md5(data)
        else:
            print("Unhandled auth type: %s" % mech)
            return self._error(memcacheConstants.ERR_AUTH, 'Auth error.')


class MemcachedBinaryChannel(asyncore.dispatcher):
    """A channel implementing the binary protocol for memcached."""

    # Receive buffer size
    BUFFER_SIZE = 4096

    def __init__(self, channel, backend, wbuf=""):
        """Wrap an accepted socket; ``wbuf`` is initial data to send."""
        asyncore.dispatcher.__init__(self, channel)
        self.log_info("New bin connection from %s" % str(self.addr))
        self.backend = backend
        self.wbuf = wbuf
        self.rbuf = ""

    def __hasEnoughBytes(self):
        """Return True when rbuf holds at least one complete request."""
        rv = False
        if len(self.rbuf) >= MIN_RECV_PACKET:
            magic, cmd, keylen, extralen, datatype, vb, remaining, opaque, \
                cas = struct.unpack(REQ_PKT_FMT, self.rbuf[:MIN_RECV_PACKET])
            rv = len(self.rbuf) - MIN_RECV_PACKET >= remaining
        return rv

    def processCommand(self, cmd, keylen, vb, cas, data):
        """Forward a decoded request to the backend."""
        return self.backend.processCommand(cmd, keylen, vb, cas, data)

    def handle_read(self):
        """Read available bytes and process every complete request.

        Each request's response (if the backend returned one) is appended
        to the write buffer.
        """
        self.rbuf += self.recv(self.BUFFER_SIZE)
        while self.__hasEnoughBytes():
            magic, cmd, keylen, extralen, datatype, vb, remaining, opaque, \
                cas = struct.unpack(REQ_PKT_FMT, self.rbuf[:MIN_RECV_PACKET])
            assert magic == REQ_MAGIC_BYTE
            assert keylen <= remaining, "Keylen is too big: %d > %d" \
                % (keylen, remaining)
            assert extralen == memcacheConstants.EXTRA_HDR_SIZES.get(cmd, 0), \
                "Extralen is too large for cmd 0x%x: %d" % (cmd, extralen)
            # Grab the data section of this request
            data = self.rbuf[MIN_RECV_PACKET:MIN_RECV_PACKET + remaining]
            assert len(data) == remaining
            # Remove this request from the read buffer
            self.rbuf = self.rbuf[MIN_RECV_PACKET + remaining:]
            # Process the command.  processCommand takes no extralen
            # argument; the backend derives the extras layout from cmd.
            cmdVal = self.processCommand(cmd, keylen, vb, cas, data)
            # Queue the response to the client if applicable.
            if cmdVal:
                try:
                    status, cas, response = cmdVal
                except ValueError:
                    print("Got %s" % (cmdVal,))
                    raise
                dtype = 0
                # NOTE(review): the response header reuses the request's
                # keylen and the request-side extras size -- confirm this is
                # what clients of this test server expect.
                extralen = memcacheConstants.EXTRA_HDR_SIZES.get(cmd, 0)
                self.wbuf += struct.pack(RES_PKT_FMT,
                                         RES_MAGIC_BYTE, cmd, keylen,
                                         extralen, dtype, status,
                                         len(response), opaque, cas) + response

    def writable(self):
        """Return True when there is buffered data waiting to be sent."""
        return bool(self.wbuf)

    def handle_write(self):
        """Send as much of the write buffer as the socket accepts."""
        sent = self.send(self.wbuf)
        self.wbuf = self.wbuf[sent:]

    def handle_close(self):
        """Log the disconnect and close the socket."""
        self.log_info("Disconnected from %s" % str(self.addr))
        self.close()


class MemcachedServer(asyncore.dispatcher):
    """A memcached server."""

    def __init__(self, backend, handler, port=11211):
        """Listen on ``port`` (all interfaces).

        ``handler`` is called as handler(socket, backend) for each accepted
        connection.
        """
        asyncore.dispatcher.__init__(self)

        self.handler = handler
        self.backend = backend

        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(("", port))
        self.listen(5)
        self.log_info("Listening on %d" % port)

    def handle_accept(self):
        """Accept a connection and hand it to the channel handler."""
        channel, addr = self.accept()
        self.handler(channel, self.backend)


if __name__ == '__main__':
    port = 11211
    import sys
    # Optional first argument overrides the listening port.
    if len(sys.argv) > 1:
        port = int(sys.argv[1])
    server = MemcachedServer(DictBackend(), MemcachedBinaryChannel, port=port)
    asyncore.loop()