1/* -*- Mode: C++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */
2/*
3 *     Copyright 2010 Couchbase, Inc
4 *
5 *   Licensed under the Apache License, Version 2.0 (the "License");
6 *   you may not use this file except in compliance with the License.
7 *   You may obtain a copy of the License at
8 *
9 *       http://www.apache.org/licenses/LICENSE-2.0
10 *
11 *   Unless required by applicable law or agreed to in writing, software
12 *   distributed under the License is distributed on an "AS IS" BASIS,
13 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *   See the License for the specific language governing permissions and
15 *   limitations under the License.
16 */
17
#include "config.h"

#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <vector>

#include "locks.h"
25
26#ifndef TESTS_MODULE_TESTS_THREADTESTS_H_
27#define TESTS_MODULE_TESTS_THREADTESTS_H_ 1
28
/**
 * Abstract source of values of type T. Concrete subclasses implement
 * operator() to produce one value per call.
 */
template <typename T>
class Generator {
public:
    /** Produce a value. */
    virtual T operator()() = 0;

    /** Virtual so a generator may be destroyed through a base pointer. */
    virtual ~Generator() {}
};
35
36class CountDownLatch {
37public:
38
39    CountDownLatch(int n=1) : count(n) {}
40
41    void decr(void) {
42        LockHolder lh(so);
43        --count;
44        so.notify();
45    }
46
47    void wait(void) {
48        LockHolder lh(so);
49        while (count > 0) {
50            so.wait();
51        }
52    }
53
54private:
55    int count;
56    SyncObject so;
57
58    DISALLOW_COPY_AND_ASSIGN(CountDownLatch);
59};
60
template <typename T> class SyncTestThread;

/**
 * Entry point handed to cb_create_thread() by SyncTestThread<T>::start().
 * @param arg the SyncTestThread<T> that was started; its run() method is
 *            executed on the new thread.
 */
template <typename T>
static void launch_sync_test_thread(void *arg) {
    static_cast<SyncTestThread<T> *>(arg)->run();
}

/*
 * Function-pointer type that thread entry points are cast to before being
 * passed to cb_create_thread(). Declared with C language linkage.
 */
extern "C" typedef void (*CB_THREAD_MAIN)(void *);
72
73template <typename T>
74class SyncTestThread {
75public:
76
77    SyncTestThread(CountDownLatch *s, CountDownLatch *p, Generator<T> *testGen) :
78        startingLine(s), pistol(p), gen(testGen) {}
79
80    SyncTestThread(const SyncTestThread &other) :
81        startingLine(other.startingLine),
82        pistol(other.pistol),
83        gen(other.gen) {}
84
85    void start(void) {
86        if (cb_create_thread(&thread, (CB_THREAD_MAIN)( launch_sync_test_thread<T> ), this, 0) != 0) {
87            throw std::runtime_error("Error initializing thread");
88        }
89    }
90
91    void run(void) {
92        startingLine->decr();
93        pistol->wait();
94        result = (*gen)();
95    }
96
97    void join(void) {
98        if (cb_join_thread(thread) != 0) {
99            throw std::runtime_error("Failed to join.");
100        }
101    }
102
103    const T getResult(void) const { return result; };
104
105private:
106    CountDownLatch *startingLine;
107    CountDownLatch *pistol;
108    Generator<T>   *gen;
109
110    T         result;
111    cb_thread_t thread;
112};
113
114template <typename T>
115static void starter(SyncTestThread<T> &t) { t.start(); }
116
117template <typename T>
118static void waiter(SyncTestThread<T> &t) { t.join(); }
119
120template <typename T>
121std::vector<T> getCompletedThreads(size_t n, Generator<T> *gen) {
122    CountDownLatch startingLine(n), pistol(1);
123
124    SyncTestThread<T> proto(&startingLine, &pistol, gen);
125    std::vector<SyncTestThread<T> > threads(n, proto);
126    cb_assert(threads.size() == n);
127    std::for_each(threads.begin(), threads.end(), starter<T>);
128
129    startingLine.wait();
130    pistol.decr();
131
132    std::for_each(threads.begin(), threads.end(), waiter<T>);
133
134    std::vector<T> results;
135    typename std::vector<SyncTestThread<T> >::iterator it;
136    for (it = threads.begin(); it != threads.end(); ++it) {
137        results.push_back(it->getResult());
138    }
139
140    return results;
141}
142
143#endif  // TESTS_MODULE_TESTS_THREADTESTS_H_
144