1import threading
2import socketserver
3import socket
4
5
class MockMemSession(socketserver.BaseRequestHandler):
    """Request handler that records one incoming chunk and dispatches it.

    ``self.server`` is expected to provide ``append()``, a ``data``
    attribute and a ``server`` attribute exposing ``basic_handler()``.
    """

    def handle(self):
        """Read up to 1024 bytes, store them, then invoke the test handler.

        The handler receives everything the server has accumulated so far
        (``self.server.data``) together with the live request socket.
        """
        chunk = self.request.recv(1024)
        recorder = self.server
        recorder.append(chunk)
        recorder.server.basic_handler(recorder.data, self.request)
11
12
class MockMemcachedServerInternal(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """Threaded TCP server that accumulates the raw chunks it receives.

    Received chunks are kept, in arrival order, in the ``data`` list.
    ``server`` is the owning object; it must provide ``log(msg)``.
    """

    # Declared here so the listening port can be rebound immediately
    # without depending on anyone having flipped the flag on
    # socketserver.TCPServer itself.
    allow_reuse_address = True

    def __init__(self, address, port, request_class, server):
        """Bind and activate the server on ``(address, port)``.

        Args:
            address: Host/IP to bind to.
            port: TCP port to bind to.
            request_class: Request handler class instantiated per connection.
            server: Owning object, stored as ``self.server``.
        """
        super(MockMemcachedServerInternal, self).__init__(server_address=(address, port),
                                                          RequestHandlerClass=request_class)
        self.data = []
        self.server = server

    def append(self, data):
        """Record one received chunk at the end of ``self.data``."""
        self.data.append(data)

    def reset(self):
        """Discard every recorded chunk."""
        self.data = []

    def log(self, msg):
        """Forward ``msg`` to the owning object's ``log()``."""
        self.server.log(msg)

    def pop(self, n):
        """Remove and return the leading recorded chunks as a list.

        ``n`` is used as a slice bound: the first ``n`` chunks are returned
        and removed; when fewer than ``n`` are stored, all of them are
        returned and ``self.data`` becomes empty.
        """
        taken, self.data = self.data[:n], self.data[n:]
        return taken
38
39
class MockMemcachedServer:
    """Test helper that runs a mock TCP server on a background thread.

    Each chunk received by the internal server is passed, together with
    everything received so far, to an optional user-supplied handler.
    """

    def __init__(self, address='127.0.0.1', port=52135, debug=False, handler=None):
        """Create (and bind) the internal server; the thread is not started.

        Args:
            address: Host/IP the server binds to.
            port: TCP port the server binds to.
            debug: When true, ``log()`` prints its messages.
            handler: Optional callable invoked as
                ``handler(data, request, debug)`` for each received chunk.
        """
        # NOTE: this sets the flag on the stdlib base class, so it also
        # applies to any other TCPServer subclass that does not override it.
        socketserver.TCPServer.allow_reuse_address = True
        self.debug = debug
        self.address, self.port = address, port
        self.server = MockMemcachedServerInternal(address, port, MockMemSession, self)
        self.server_thread = threading.Thread(target=self.server.serve_forever)
        self.test_handler = handler
        self.running = False

    def set_debug(self, debug):
        """Turn debug printing on or off."""
        self.debug = debug

    def set_handler(self, handler):
        """Install the callable invoked for each received chunk."""
        self.test_handler = handler

    def start(self):
        """Mark the server as running and start the serving thread."""
        self.running = True
        message = 'Starting server thread at {}:{}'.format(self.address, self.port)
        self.log(message)
        self.server_thread.start()

    def basic_handler(self, data, req):
        """Log the accumulated data and delegate to the test handler, if any.

        Args:
            data: The chunks recorded so far.
            req: The request socket of the current connection.
        """
        message = 'Data: {}'.format(data)
        self.log(message)
        handler = self.test_handler
        if not handler:
            return
        handler(data, req, self.debug)

    def stop(self):
        """Shut the server down, close its socket and join the thread.

        ``shutdown()`` and ``join()`` are only attempted when ``start()``
        was called; the listening socket is closed in either case.
        """
        was_running = self.running

        self.log('Shut down server')
        if was_running:
            self.server.shutdown()

        self.log('Socket close')
        self.server.socket.close()

        self.log('Joining')
        if was_running:
            self.server_thread.join()

        self.running = False
        self.log('Thread finished')

    def reset(self):
        """Drop the test handler and clear the recorded data."""
        self.test_handler = None
        self.server.reset()

    def get_host_address(self):
        """Return the ``(address, port)`` pair given at construction."""
        return (self.address, self.port)

    def log(self, msg):
        """Print ``msg`` when debug mode is enabled."""
        if not self.debug:
            return
        print(msg)
90
91
def fake_conn():
    """Return a stream socket that attempted to connect to 127.0.0.1:5235.

    Every address returned by ``getaddrinfo`` is tried in turn. The
    connection attempt uses ``connect_ex()``, whose error code is ignored,
    so the returned socket is not guaranteed to be connected. The socket
    has a 10 second timeout.

    Returns:
        The first socket that could be created and configured.

    Raises:
        socket.error: The error from the last failed attempt, or a generic
            error when ``getaddrinfo`` yields no addresses.
    """
    # Python 3 unbinds the ``except ... as`` name when the handler exits,
    # so the last failure is kept in a separate variable.
    last_error = None
    candidates = socket.getaddrinfo('127.0.0.1', 5235, socket.AF_UNSPEC,
                                    socket.SOCK_STREAM)
    for family, socktype, proto, _canonname, sockaddr in candidates:
        sock = None
        try:
            sock = socket.socket(family, socktype, proto)
            sock.settimeout(10)
            # connect_ex() reports connect() failures through its return
            # value instead of raising; the result is deliberately unused.
            sock.connect_ex(sockaddr)
            return sock
        except socket.error as sock_error:
            last_error = sock_error
            # Close explicitly rather than waiting for garbage collection.
            if sock is not None:
                sock.close()

    if last_error is not None:
        raise last_error
    raise socket.error('getaddrinfo returns an empty list')
109
110