1import threading
2import random
3import zlib
4import time
5import copy
6
7
class KVStore(object):
    """Key/value store sharded over ``num_locks`` lock-protected partitions.

    ``self.cache`` maps each partition index to
    ``{"lock": threading.Lock(), "partition": Partition(index)}``. A key's
    partition index is derived from the CRC32 of the key (see ``_hash``).
    Callers take a partition's lock through one of the ``acquire_*`` methods,
    operate on the returned Partition, and then release it with the matching
    ``release_*`` method.
    """

    def __init__(self, num_locks=16):
        self.num_locks = num_locks
        self.reset()

    def reset(self):
        """Discard all data and build ``num_locks`` fresh, unlocked partitions."""
        # NOTE(review): not coordinated with threads that currently hold a
        # partition lock; they keep references to the old cache entries.
        self.cache = {}
        for itr in range(self.num_locks):
            self.cache[itr] = {"lock": threading.Lock(),
                               "partition": Partition(itr)}

    def acquire_partition(self, key):
        """Lock the partition owning ``key`` and return its Partition object.

        Blocks until the lock is available. Release with
        ``release_partition(key)``.
        """
        partition = self.cache[self._hash(key)]
        partition["lock"].acquire()
        return partition["partition"]

    def acquire_partitions(self, keys):
        """Lock every partition touched by ``keys`` and group the keys by partition.

        Returns a dict mapping each Partition object to the list of keys (in
        input order, duplicates kept) that hash to it. Release with
        ``release_partitions(result)``.

        All keys are hashed before any lock is taken, so a key that fails to
        hash leaves no locks held. Locks are then acquired in ascending
        partition-index order, so concurrent calls with overlapping key sets
        cannot deadlock against each other.
        """
        part_obj_keys = {}
        indices = set()
        for key in keys:
            index = self._hash(key)
            indices.add(index)
            part_obj_keys.setdefault(self.cache[index]["partition"], []).append(key)
        # Global lock ordering prevents lock-order-inversion deadlocks.
        for index in sorted(indices):
            self.cache[index]["lock"].acquire()
        return part_obj_keys

    def release_partitions(self, partition_objs):
        """Release the lock of each Partition in ``partition_objs``.

        ``partition_objs`` may be the dict returned by ``acquire_partitions``
        (iterating a dict yields its keys, i.e. the Partition objects).
        """
        for partition_obj in partition_objs:
            partition = self.cache[partition_obj.part_id]
            partition["lock"].release()

    def release_partition(self, key):
        """Release one partition lock, identified by a key or a partition index.

        ``key`` may be a text/bytes key (hashed to find its partition) or an
        int partition index, e.g. the index returned by
        ``acquire_random_partition``.

        Raises:
            Exception: if ``key`` is of any other type.
        """
        if isinstance(key, (str, bytes)):
            index = self._hash(key)
        elif isinstance(key, int):
            index = key
        else:
            raise Exception("Bad key")
        self.cache[index]["lock"].release()

    def acquire_random_partition(self, has_valid=True):
        """Lock and return a partition that has valid (or deleted) keys.

        Scans all partitions once, starting at a random index and wrapping
        around. It stops at the first partition that ``has_valid_keys()``
        (when ``has_valid`` is true) or ``has_deleted_keys()`` (otherwise).

        Returns:
            ``(partition, index)`` with that partition's lock still held; the
            caller must ``release_partition(index)``. ``(None, None)`` if no
            partition qualifies, in which case no lock is held.
        """
        seed = random.choice(range(self.num_locks))
        for offset in range(self.num_locks):
            part_num = (seed + offset) % self.num_locks
            entry = self.cache[part_num]
            entry["lock"].acquire()
            partition = entry["partition"]
            if has_valid:
                found = partition.has_valid_keys()
            else:
                found = partition.has_deleted_keys()
            if found:
                return partition, part_num
            entry["lock"].release()
        return None, None

    def key_set(self):
        """Return ``(valid_keys, deleted_keys)`` lists gathered from every partition.

        Each partition is locked only while it is read, so the result is not
        an atomic snapshot of the whole store.
        """
        valid_keys = []
        deleted_keys = []
        for itr in range(self.num_locks):
            entry = self.cache[itr]
            # ``with`` guarantees the lock is released even if a read raises.
            with entry["lock"]:
                partition = entry["partition"]
                valid_keys.extend(partition.valid_key_set())
                deleted_keys.extend(partition.deleted_key_set())
        return valid_keys, deleted_keys

    def __len__(self):
        """Return the total number of live keys across all partitions."""
        # NOTE(review): partitions are read without their locks even though
        # Partition.__len__ can mutate state (it expires keys). Locking here
        # would deadlock a caller that already holds a partition, because
        # threading.Lock is not reentrant -- confirm the intended contract.
        return sum(len(self.cache[itr]["partition"]) for itr in range(self.num_locks))

    def _hash(self, key):
        """Map ``key`` to a partition index in ``[0, num_locks)`` using CRC32.

        Text keys are UTF-8 encoded first, because Python 3's ``zlib.crc32``
        accepts only bytes-like data. The checksum is masked to an unsigned
        32-bit value so the index is the same on Python 2 and 3.
        """
        data = key
        if not isinstance(data, bytes) and hasattr(data, "encode"):
            data = data.encode("utf-8")
        return (zlib.crc32(data) & 0xffffffff) % self.num_locks
76
77
class Partition(object):
    """Storage for one shard of a KVStore.

    Keeps three collections:

    * live entries: ``key -> {"value": ..., "expires": ..., "flag": ...}``,
      where ``expires`` is an absolute ``time.time()`` deadline, or 0 for
      "never expires";
    * deleted keys: ``key -> last value``, whether removed by ``delete`` or
      by expiry;
    * expired keys: the deleted keys that got there by expiring.

    Expiry is lazy: an entry past its deadline moves to the deleted set only
    when an accessor checks it. This class does no locking of its own;
    KVStore pairs each instance with a ``threading.Lock``.
    """

    def __init__(self, part_id):
        self.part_id = part_id
        self.__valid = {}         # key -> {"value", "expires", "flag"}
        self.__deleted = {}       # key -> value at deletion/expiry time
        self.__expired_keys = []  # keys moved to __deleted by expiry

    def set(self, key, value, exp=0, flag=0):
        """Store ``value`` under ``key``, reviving it if it was deleted or expired.

        ``exp`` is a time-to-live in seconds from now; 0 means never expire
        (a negative value produces an entry that is already overdue).
        ``flag`` is stored as-is and returned by ``get_flag``.
        """
        self.__deleted.pop(key, None)
        if key in self.__expired_keys:
            self.__expired_keys.remove(key)
        expires = (time.time() + exp) if exp != 0 else exp
        self.__valid[key] = {"value": value,
                             "expires": expires,
                             "flag": flag}

    def delete(self, key):
        """Move a live ``key`` to the deleted set; no-op if it is not live.

        Expiry is not applied first, so an overdue key is recorded as
        deleted rather than expired.
        """
        if key in self.__valid:
            self.__deleted[key] = self.__valid.pop(key)["value"]

    def get_key(self, key):
        """Return the raw live entry dict for ``key`` or None, without applying expiry."""
        return self.__valid.get(key)

    def get_valid(self, key):
        """Return the live value for ``key`` (after applying expiry), or None."""
        self.__expire_key(key)
        if key in self.__valid:
            return self.__valid[key]["value"]
        return None

    def get_deleted(self, key):
        """Return the last value of a deleted/expired ``key`` (after applying expiry), or None."""
        self.__expire_key(key)
        return self.__deleted.get(key)

    def get_random_valid_key(self):
        """Return a random live key (after applying expiry), or None if there is none."""
        keys = self.valid_key_set()
        return random.choice(keys) if keys else None

    def get_random_deleted_key(self):
        """Return a random deleted key, or None if there is none.

        Unlike ``deleted_key_set``, this does not apply expiry first.
        """
        # list() keeps random.choice working on Python 3, where keys() is a view.
        keys = list(self.__deleted)
        return random.choice(keys) if keys else None

    def get_flag(self, key):
        """Return the flag stored for a live ``key`` (after applying expiry), or None."""
        self.__expire_key(key)
        if key in self.__valid:
            return self.__valid[key]["flag"]
        return None

    def valid_key_set(self):
        """Apply expiry to all live keys and return the remaining live keys as a list."""
        self.__expire_all()
        return list(self.__valid)

    def deleted_key_set(self):
        """Apply expiry to all live keys and return the deleted keys as a list."""
        self.__expire_all()
        return list(self.__deleted)

    def expired_key_set(self):
        """Apply expiry to all live keys and return the expired keys.

        The internal list itself is returned, not a copy.
        """
        self.__expire_all()
        return self.__expired_keys

    def has_valid_keys(self):
        """Return True if any entry is live; expiry is not applied first."""
        return len(self.__valid) > 0

    def has_deleted_keys(self):
        """Return True if any key has been deleted or has expired."""
        return len(self.__deleted) > 0

    def __expire_key(self, key):
        """Move ``key`` from live to deleted+expired if its deadline has passed."""
        if key in self.__valid:
            expires = self.__valid[key]["expires"]
            if expires != 0 and expires < time.time():
                self.__deleted[key] = self.__valid[key]["value"]
                self.__expired_keys.append(key)
                del self.__valid[key]

    def __expire_all(self):
        """Apply expiry to every live key."""
        # Iterate over a snapshot: __expire_key deletes from __valid, which
        # would break iteration over a live dict view on Python 3.
        for key in list(self.__valid):
            self.__expire_key(key)

    def expired(self, key):
        """Return True if ``key`` was removed by expiry (after applying expiry).

        Raises:
            Exception: if ``key`` is neither live nor deleted.
        """
        if key not in self.__valid and key not in self.__deleted:
            # Wrap in a tuple so tuple keys do not break %-formatting.
            raise Exception("Key: %s is not a valid key" % (key,))
        self.__expire_key(key)
        return key in self.__expired_keys

    def __len__(self):
        """Return the number of live keys after applying expiry to all of them."""
        self.__expire_all()
        return len(self.__valid)

    def __eq__(self, other):
        if isinstance(other, Partition):
            return self.part_id == other.part_id
        return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.part_id)
178