1 /* -*- Mode: C++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /*
3  *     Copyright 2017 Couchbase, Inc
4  *
5  *   Licensed under the Apache License, Version 2.0 (the "License");
6  *   you may not use this file except in compliance with the License.
7  *   You may obtain a copy of the License at
8  *
9  *       http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *   Unless required by applicable law or agreed to in writing, software
12  *   distributed under the License is distributed on an "AS IS" BASIS,
13  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *   See the License for the specific language governing permissions and
15  *   limitations under the License.
16  */
17 
18 #pragma once
19 
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <mutex>
21 
/**
 * Object which is used to synchronise the execution of a number of threads.
 *
 * Each participating thread calls threadUp(). Every caller blocks until the
 * configured number of threads have called it, at which point all of them
 * are released together. Other threads may observe progress via getCount()
 * / isComplete(), or block for a bounded time via waitFor().
 *
 * The gate is single-use: the arrival count is never reset. Calls to
 * threadUp() beyond the configured number (including any call on a gate
 * configured for zero threads) return immediately instead of blocking.
 *
 * Every public member function acquires the internal mutex before touching
 * the gate's state.
 */
class ThreadGate {
public:
    /// Create a gate which waits for zero threads (i.e. is already complete).
    ThreadGate() : ThreadGate(0) {
    }

    /**
     * Create a ThreadGate.
     * @param n_threads_ Total number of threads to wait for.
     */
    ThreadGate(size_t n_threads_) : n_threads(n_threads_) {
    }

    /**
     * Register the calling thread as having arrived at the gate.
     *
     * The arrival count is incremented while holding the mutex. If this
     * brings the count up to (or past) the expected number, all waiters are
     * woken and the call returns immediately; otherwise the calling thread
     * blocks until the gate completes.
     */
    void threadUp() {
        std::unique_lock<std::mutex> lh(m);
        ++thread_count;
        if (isComplete(lh)) {
            // All threads accounted for - release everyone. Repeating this
            // for a surplus caller is harmless.
            cv.notify_all();
        } else {
            cv.wait(lh, [this, &lh]() { return isComplete(lh); });
        }
    }

    /**
     * Block until the gate completes or @p timeout elapses, whichever comes
     * first.
     * @param timeout Maximum time to wait.
     * @return true if the gate was complete on return, false on timeout.
     */
    template <typename Rep, typename Period>
    bool waitFor(std::chrono::duration<Rep, Period> timeout) {
        std::unique_lock<std::mutex> lh(m);
        return cv.wait_for(
                lh, timeout, [this, &lh]() { return isComplete(lh); });
    }

    /// @return the number of threadUp() calls made so far.
    size_t getCount() const {
        std::lock_guard<std::mutex> lh(m);
        return thread_count;
    }

    /// @return true once at least the expected number of threads arrived.
    bool isComplete() const {
        std::unique_lock<std::mutex> lh(m);
        return isComplete(lh);
    }

private:
    /**
     * Completion predicate. The lock argument is unused at runtime; it marks
     * that the caller is expected to hold `m` while evaluating this.
     */
    bool isComplete(const std::unique_lock<std::mutex>&) const {
        // >= rather than == so that surplus threadUp() calls (or calls on a
        // gate expecting zero threads) cannot make the predicate permanently
        // false and block callers forever.
        return thread_count >= n_threads;
    }

    const size_t n_threads;
    size_t thread_count{0};
    // mutable so that the const observers above can lock it.
    mutable std::mutex m;
    std::condition_variable cv;
};
77