1 /* -*- Mode: C++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /*
3  *     Copyright 2010 Couchbase, Inc
4  *
5  *   Licensed under the Apache License, Version 2.0 (the "License");
6  *   you may not use this file except in compliance with the License.
7  *   You may obtain a copy of the License at
8  *
9  *       http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *   Unless required by applicable law or agreed to in writing, software
12  *   distributed under the License is distributed on an "AS IS" BASIS,
13  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *   See the License for the specific language governing permissions and
15  *   limitations under the License.
16  */
17 
#include "config.h"

#include <algorithm>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <vector>

#include "locks.h"
#include "syncobject.h"
#include "utility.h"
27 
28 #ifndef TESTS_MODULE_TESTS_THREADTESTS_H_
29 #define TESTS_MODULE_TESTS_THREADTESTS_H_ 1
30 
/**
 * Abstract producer of a value of type T.
 *
 * getCompletedThreads() hands the same instance to every test thread, and
 * each thread invokes operator() once, so implementations used there may be
 * called from several threads at the same time.
 */
template <class T>
class Generator {
public:
    /** Polymorphic base: allow deletion through a Generator<T> pointer. */
    virtual ~Generator() = default;

    /** Produce one value. */
    virtual T operator()() = 0;
};
37 
38 class CountDownLatch {
39 public:
40 
CountDownLatch(int n=1)41     CountDownLatch(int n=1) : count(n) {}
42 
decr(void)43     void decr(void) {
44         std::unique_lock<std::mutex> lh(so);
45         --count;
46         so.notify_all();
47     }
48 
wait(void)49     void wait(void) {
50         std::unique_lock<std::mutex> lh(so);
51         while (count > 0) {
52             so.wait(lh);
53         }
54     }
55 
56 private:
57     int count;
58     SyncObject so;
59 
60     DISALLOW_COPY_AND_ASSIGN(CountDownLatch);
61 };
62 
template <typename T>
class SyncTestThread;

/**
 * Thread entry point given to cb_create_thread() by SyncTestThread<T>::start().
 *
 * @param arg the SyncTestThread<T> that started the thread, passed through
 *            as an opaque pointer; its run() is executed on this thread.
 */
template <typename T>
static void launch_sync_test_thread(void *arg) {
    SyncTestThread<T> *const self = static_cast<SyncTestThread<T> *>(arg);
    self->run();
}
70 
// Signature of a thread entry function, declared with C language linkage
// to match what cb_create_thread() expects.
extern "C" typedef void (*CB_THREAD_MAIN)(void *);
74 
75 template <typename T>
76 class SyncTestThread {
77 public:
78 
SyncTestThread(CountDownLatch *s, CountDownLatch *p, Generator<T> *testGen)79     SyncTestThread(CountDownLatch *s, CountDownLatch *p, Generator<T> *testGen) :
80         startingLine(s), pistol(p), gen(testGen) {}
81 
SyncTestThread(const SyncTestThread &other)82     SyncTestThread(const SyncTestThread &other) :
83         startingLine(other.startingLine),
84         pistol(other.pistol),
85         gen(other.gen) {}
86 
start(void)87     void start(void) {
88         if (cb_create_thread(&thread, (CB_THREAD_MAIN)( launch_sync_test_thread<T> ), this, 0) != 0) {
89             throw std::runtime_error("Error initializing thread");
90         }
91     }
92 
run(void)93     void run(void) {
94         startingLine->decr();
95         pistol->wait();
96         result = (*gen)();
97     }
98 
join(void)99     void join(void) {
100         if (cb_join_thread(thread) != 0) {
101             throw std::runtime_error("Failed to join.");
102         }
103     }
104 
getResult(void) const105     const T getResult(void) const { return result; };
106 
107 private:
108     CountDownLatch *startingLine;
109     CountDownLatch *pistol;
110     Generator<T>   *gen;
111 
112     T         result;
113     cb_thread_t thread;
114 };
115 
116 template <typename T>
starter(SyncTestThread<T> &t)117 static void starter(SyncTestThread<T> &t) { t.start(); }
118 
119 template <typename T>
waiter(SyncTestThread<T> &t)120 static void waiter(SyncTestThread<T> &t) { t.join(); }
121 
122 template <typename T>
getCompletedThreads(size_t n, Generator<T> *gen)123 std::vector<T> getCompletedThreads(size_t n, Generator<T> *gen) {
124     CountDownLatch startingLine(n), pistol(1);
125 
126     SyncTestThread<T> proto(&startingLine, &pistol, gen);
127     std::vector<SyncTestThread<T> > threads(n, proto);
128     cb_assert(threads.size() == n);
129     std::for_each(threads.begin(), threads.end(), starter<T>);
130 
131     startingLine.wait();
132     pistol.decr();
133 
134     std::for_each(threads.begin(), threads.end(), waiter<T>);
135 
136     std::vector<T> results;
137     typename std::vector<SyncTestThread<T> >::iterator it;
138     for (it = threads.begin(); it != threads.end(); ++it) {
139         results.push_back(it->getResult());
140     }
141 
142     return results;
143 }
144 
145 #endif  // TESTS_MODULE_TESTS_THREADTESTS_H_
146