1/* -*- Mode: C++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */ 2/* 3 * Copyright 2010 Couchbase, Inc 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18#include "config.h" 19 20#include <algorithm> 21#include <iostream> 22#include <vector> 23 24#include "locks.h" 25 26#ifndef TESTS_MODULE_TESTS_THREADTESTS_H_ 27#define TESTS_MODULE_TESTS_THREADTESTS_H_ 1 28 29template <typename T> 30class Generator { 31public: 32 virtual ~Generator() {} 33 virtual T operator()() = 0; 34}; 35 36class CountDownLatch { 37public: 38 39 CountDownLatch(int n=1) : count(n) {} 40 41 void decr(void) { 42 LockHolder lh(so); 43 --count; 44 so.notify(); 45 } 46 47 void wait(void) { 48 LockHolder lh(so); 49 while (count > 0) { 50 so.wait(); 51 } 52 } 53 54private: 55 int count; 56 SyncObject so; 57 58 DISALLOW_COPY_AND_ASSIGN(CountDownLatch); 59}; 60 61template <typename T> class SyncTestThread; 62 63template <typename T> 64static void launch_sync_test_thread(void *arg) { 65 SyncTestThread<T> *stt = static_cast<SyncTestThread<T>*>(arg); 66 stt->run(); 67} 68 69extern "C" { 70 typedef void (*CB_THREAD_MAIN)(void *); 71} 72 73template <typename T> 74class SyncTestThread { 75public: 76 77 SyncTestThread(CountDownLatch *s, CountDownLatch *p, Generator<T> *testGen) : 78 startingLine(s), pistol(p), gen(testGen) {} 79 80 SyncTestThread(const SyncTestThread &other) : 81 startingLine(other.startingLine), 82 pistol(other.pistol), 83 gen(other.gen) {} 84 85 void start(void) { 86 if (cb_create_thread(&thread, (CB_THREAD_MAIN)( launch_sync_test_thread<T> ), this, 0) != 0) { 87 throw std::runtime_error("Error initializing thread"); 88 } 89 } 90 91 void run(void) { 92 startingLine->decr(); 93 pistol->wait(); 94 result = (*gen)(); 95 } 96 97 void join(void) { 98 if (cb_join_thread(thread) != 0) { 99 throw std::runtime_error("Failed to join."); 100 } 101 } 102 103 const T getResult(void) const { return result; }; 104 105private: 106 CountDownLatch *startingLine; 107 CountDownLatch *pistol; 108 Generator<T> *gen; 109 110 T result; 111 cb_thread_t thread; 112}; 113 114template <typename T> 115static void starter(SyncTestThread<T> &t) { t.start(); } 116 117template <typename T> 118static void waiter(SyncTestThread<T> &t) { t.join(); } 119 120template <typename T> 121std::vector<T> getCompletedThreads(size_t n, Generator<T> *gen) { 122 CountDownLatch startingLine(n), pistol(1); 123 124 SyncTestThread<T> proto(&startingLine, &pistol, gen); 125 std::vector<SyncTestThread<T> > threads(n, proto); 126 cb_assert(threads.size() == n); 127 std::for_each(threads.begin(), threads.end(), starter<T>); 128 129 startingLine.wait(); 130 pistol.decr(); 131 132 std::for_each(threads.begin(), threads.end(), waiter<T>); 133 134 std::vector<T> results; 135 typename std::vector<SyncTestThread<T> >::iterator it; 136 for (it = threads.begin(); it != threads.end(); ++it) { 137 results.push_back(it->getResult()); 138 } 139 140 return results; 141} 142 143#endif // TESTS_MODULE_TESTS_THREADTESTS_H_ 144