1import sys
2import string
3import socket
4import select
5import unittest
6import threading
7import time
8import re
9import struct
10
11from memcacheConstants import REQ_MAGIC_BYTE, RES_MAGIC_BYTE
12from memcacheConstants import REQ_PKT_FMT, RES_PKT_FMT, MIN_RECV_PACKET
13from memcacheConstants import SET_PKT_FMT, DEL_PKT_FMT, INCRDECR_RES_FMT
14
15import memcacheConstants
16
def debug(level, x):
    """Print ``x`` when ``level`` is below 1.

    Only level-0 (or negative) messages are emitted; level-1 calls are
    effectively disabled trace output.
    """
    if level >= 1:
        return
    print(x)
20
21# A fake memcached server.
22#
class MockServer(threading.Thread):
    """A fake memcached server listening on ``port``.

    Every accepted connection is wrapped in a ``MockSession`` thread and
    stored in ``self.sessions`` under a 0-based index; the index restarts at
    0 after ``closeSessions()`` resets the table.  Runs as a daemon thread.
    """

    def __init__(self, port):
        threading.Thread.__init__(self)
        self.daemon = True
        self.host     = ''       # '' binds to all interfaces
        self.port     = port
        self.backlog  = 5
        self.server   = None     # listening socket, created in run()
        self.running  = False
        self.sessions = {}       # index -> MockSession

    def closeSessions(self):
        """Close every current session and reset the session table."""
        # Snapshot and reset first so sessions added concurrently by the
        # accept loop don't modify the dict being iterated.
        sessions = self.sessions
        self.sessions = {}

        for session in sessions.values():
            session.close()

    def close(self):
        """Stop the accept loop, close all sessions and the listening socket."""
        self.running = False
        self.closeSessions()
        if self.server:
            self.server.close()
        self.server = None

    def run(self):
        """Bind, listen and spawn a MockSession per accepted connection.

        On a socket error the server is closed and the thread exits via
        ``sys.exit(1)``.  Inside a worker thread that only raises SystemExit
        and ends this thread, not the process.
        """
        self.running = True
        try:
            # Keep a local reference: close() from another thread sets
            # self.server to None, and accept() on the closed socket then
            # raises socket.error (handled below) instead of AttributeError.
            server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.server = server
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind((self.host, self.port))
            server.listen(self.backlog)

            while self.running:
                debug(0, "MockServer running " + str(self.port))
                client, address = server.accept()
                c = MockSession(client, address, self)
                debug(0, "MockServer accepted " + str(self.port))
                self.sessions[len(self.sessions)] = c
                c.start()

        except KeyboardInterrupt:
            self.close()
            raise
        except socket.error as e:
            # Not every socket.error carries (errno, message) args, so don't
            # tuple-unpack it; fall back to str(e) when strerror is missing.
            self.close()
            message = getattr(e, 'strerror', None) or str(e)
            debug(1, "MockServer socket error: " + message)
            sys.exit(1)

        self.close()
73
74# A session in the fake memcached server.
75#
class MockSession(threading.Thread):
    """One client connection accepted by ``MockServer``.

    Every chunk returned by ``recv()`` is appended to ``self.received``
    (oldest first) so tests can inspect what the peer sent.  The polling
    loop runs at most ``running_max - 1`` iterations, each waiting up to one
    second for data, before the session shuts itself down.
    """

    def __init__(self, client, address, server, recvlen_in=1024):
        threading.Thread.__init__(self)
        self.daemon = True
        self.server  = server
        self.client  = client
        self.address = address
        self.recvlen = recvlen_in      # max bytes per recv() call
        self.running     = 0           # 0 = stopped, otherwise loop counter
        self.running_max = 10          # counter value at which the loop gives up
        self.received = []             # recv() chunks, oldest first

    def run(self):
        """Poll the client socket until EOF, an error, close() or the iteration cap."""
        client = self.client

        try:
            self.running = 1
            while 0 < self.running < self.running_max:
                debug(1, "MockSession running (" + str(self.running) + ")")
                self.running = self.running + 1

                # Only readability is polled.  No sockets are passed for
                # exceptional conditions, so there is no error list to check.
                readable, _, _ = select.select([client], [], [], 1)
                if not readable:
                    continue

                debug(1, "MockSession recv...")
                data = client.recv(self.recvlen)
                # str() keeps this from raising TypeError if data is bytes.
                debug(1, "MockSession recv done:" + str(data))

                if data:
                    self.received.append(data)
                else:
                    # An empty read means the peer closed the connection.
                    debug(1, "MockSession recv no data")
                    self.close()

        except KeyboardInterrupt:
            raise
        except Exception as e:
            # Best-effort: close() can be called from another thread (e.g.
            # MockServer.closeSessions), after which select()/recv() on the
            # closed socket raise.  Treat any such error as end of session.
            debug(1, "MockSession error: " + str(e))
        finally:
            # Always release the socket, including on KeyboardInterrupt.
            if self.running >= self.running_max:
                debug(1, "MockSession running too long, shutting down")

            debug(1, "MockSession closing")
            self.close()

    def close(self):
        """Stop the polling loop and close the client socket (safe to call twice)."""
        debug(1, "MockSession close")
        self.running = 0
        if self.client:
            self.client.close()
        self.client = None
130
# Start a fake memcached server...
#
# sys.setcheckinterval(0) asked the Python 2 interpreter to consider a thread
# switch after every bytecode, so the mock threads get scheduled promptly.
# It was a no-op from Python 3.2 and was removed in 3.9, so only call it when
# it exists.  On Python 3, thread switching is governed by
# sys.setswitchinterval().
if hasattr(sys, 'setcheckinterval'):
    sys.setcheckinterval(0)
g_mock_server_port = 11311
g_mock_server = MockServer(g_mock_server_port)
g_mock_server.start()
# Crude wait for the server thread to bind and start listening.
time.sleep(1)
138
class ProxyClientBase(unittest.TestCase):
    """Base class for tests that talk to a moxi proxy in front of MockServer.

    Test clients connect to the proxy on ``proxy_port``.  The proxy is
    expected to forward requests to the module-level mock server, whose
    sessions the ``mock_*`` helpers drive.  Client sockets are kept in
    ``self.clients`` keyed by an integer index (None once closed).
    """

    vbucketId = 0

    def __init__(self, x):
        unittest.TestCase.__init__(self, x)

        # These tests assume a moxi proxy is running at
        # the self.proxy_port and is forwarding requests
        # to our fake memcached server.
        #
        # TODO: Fork a moxi proxy like the perl tests.
        #
        self.proxy_port = 11333
        self.clients = {}  # index -> client socket (None once closed)

    def mock_server(self):
        """Return the module-level MockServer instance."""
        return g_mock_server

    def setUp(self):
        """No per-test setup; clients connect on demand via client_connect()."""

    def tearDown(self):
        """Close all mock sessions and every client connection opened by the test."""
        self.mock_close()
        for k in list(self.clients):
            self.client_close(k)
        # Reset to an empty dict (not a list): client_connect() assigns by key.
        self.clients = {}

    def client_connect(self, idx=0):
        """Connect to the proxy on localhost, store the socket as client ``idx`` and return it."""
        c = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            c.connect(("127.0.0.1", self.proxy_port))
        except Exception:
            c.close()  # don't leak the socket when the proxy isn't reachable
            raise
        self.clients[idx] = c
        return c

    def client_send(self, what, idx=0):
        """Send all of ``what`` to the proxy over client ``idx``."""
        debug(1, "client sending " + what)
        # sendall() retries until every byte is written; send() may stop short.
        self.clients[idx].sendall(what)

    def mock_send(self, what, session_idx=0):
        """Send all of ``what`` from mock session ``session_idx`` back to the proxy."""
        debug(1, "mock sending " + what)

        # Use the waiting accessor, as mock_recv_message() does, so a session
        # that is still being accepted doesn't cause a spurious KeyError.
        session = self.mock_session(session_idx)

        self.assertTrue(session.client is not None)

        session.client.sendall(what)

    def client_recv(self, what, idx=0, num_bytes=1024):
        """Read up to ``num_bytes`` from client ``idx`` and assert it matches ``what``.

        The data matches if it equals ``what`` exactly, or else if ``what``
        matches it as a regular expression anchored at the start (re.match).
        """
        debug(1, "client_recv expect: '" + what + "'")

        s = self.clients[idx].recv(num_bytes)

        debug(1, "client_recv actual: '" + s + "'")

        self.assertTrue(what == s or re.match(what, s) is not None)

    def mock_session(self, session_idx=0):
        """Return mock session ``session_idx``, waiting for it to be accepted.

        Sleeps 1, 2 and then 4 seconds (about 7 s in total) while the session
        is missing.  Raises KeyError if it never appears.
        """
        wait_max = 5

        i = 1
        while len(self.mock_server().sessions) <= session_idx and i < wait_max:
            time.sleep(i)
            i = i * 2

        if len(self.mock_server().sessions) <= session_idx and i >= wait_max:
            debug(1, "waiting too long for mock_session " + str(i))

        return self.mock_server().sessions[session_idx]

    def mock_recv_message(self, session_idx=0):
        """Pop and return the oldest chunk received by mock session ``session_idx``.

        A chunk is one recv() result, not necessarily one protocol message.
        Waits with the same 1/2/4 s back-off as mock_session() and returns
        '' if nothing arrives.
        """
        session = self.mock_session(session_idx)

        wait_max = 5
        i = 1
        while len(session.received) <= 0 and i < wait_max:
            debug(1, "sleeping waiting for mock_recv " + str(i))
            time.sleep(i)
            i = i * 2

        if len(session.received) <= 0 and i >= wait_max:
            debug(1, "waiting too long for mock_recv " + str(i))

        message = ""
        if len(session.received) > 0:
            message = session.received.pop(0)

        return message

    def mock_recv(self, what, session_idx=0):
        """Assert the next chunk received by the mock equals or re.match()es ``what``."""
        # Useful for ascii messages.
        debug(1, "mock_recv expect: " + what)
        message = self.mock_recv_message(session_idx)
        debug(1, "mock_recv actual: " + message)
        self.assertTrue(what == message or re.match(what, message) is not None)

    def wait(self, x):
        """Sleep for ``x`` hundredths of a second."""
        debug(1, "wait " + str(x))
        time.sleep(0.01 * x)

    def client_close(self, idx=0):
        """Close client ``idx`` if it is open and mark its slot as closed (None)."""
        c = self.clients.get(idx)  # tolerate indexes that were never connected
        if c:
            c.close()
        self.clients[idx] = None

    def mock_close(self):
        """Close every session on the mock server (session indexes restart at 0)."""
        if self.mock_server():
            self.mock_server().closeSessions()

    def mock_quiet(self, session_idx=0):
        """Return True if session ``session_idx`` doesn't exist or has no unread data."""
        if len(self.mock_server().sessions) <= session_idx:
            return True

        session = self.mock_server().sessions[session_idx]

        return len(session.received) <= 0

    def dump(self, header, prefix=''):
        """Return a hex dump of at most MIN_RECV_PACKET leading bytes of ``header``.

        Bytes are printed as '0xNN ', four per line, and each line starts with
        ``prefix + ' '``.  Accepts str (Python 2) or bytes (Python 3).
        """
        length = min(len(header), MIN_RECV_PACKET)
        parts = []
        for i in range(length):
            c = header[i]
            if i % 4 == 0:
                parts.append(prefix + ' ')
            # Indexing bytes yields int on Python 3; str yields a 1-char str.
            parts.append('0x%02X ' % (c if isinstance(c, int) else ord(c)))
            if i % 4 == 3:
                parts.append('\n')
        return ''.join(parts)

    def packReq(self, cmd, reserved=0, key='', val='', opaque=0, extraHeader='', cas=0):
        """Build a request: a REQ_PKT_FMT header followed by extraHeader, key and val.

        The header's body-length field is len(extraHeader) + len(key) + len(val).
        ``reserved`` fills the header slot after the data type; presumably this is
        the vbucket id -- confirm against memcacheConstants.REQ_PKT_FMT.
        """
        dtype = 0
        bodylen = len(key) + len(extraHeader) + len(val)
        msg = struct.pack(REQ_PKT_FMT, REQ_MAGIC_BYTE,
                          cmd, len(key), len(extraHeader), dtype, reserved,
                          bodylen, opaque, cas)
        return msg + extraHeader + key + val

    def packRes(self, cmd, status=0, key='', val='', opaque=0, extraHeader='', cas=0):
        """Build a response: a RES_PKT_FMT header followed by extraHeader, key and val.

        Same layout as packReq(), with the response magic byte and ``status``
        in place of ``reserved``.
        """
        dtype = 0
        bodylen = len(key) + len(extraHeader) + len(val)
        msg = struct.pack(RES_PKT_FMT, RES_MAGIC_BYTE,
                          cmd, len(key), len(extraHeader), dtype, status,
                          bodylen, opaque, cas)
        return msg + extraHeader + key + val
282
283