1 /* -*- Mode: C++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /*
3  *     Copyright 2017 Couchbase, Inc
4  *
5  *   Licensed under the Apache License, Version 2.0 (the "License");
6  *   you may not use this file except in compliance with the License.
7  *   You may obtain a copy of the License at
8  *
9  *       http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *   Unless required by applicable law or agreed to in writing, software
12  *   distributed under the License is distributed on an "AS IS" BASIS,
13  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *   See the License for the specific language governing permissions and
15  *   limitations under the License.
16  */
17 
/**
 * Reusable rendezvous point for a group of threads.
 *
 * Every participant calls wait(). Callers block until the expected number
 * of threads has arrived; the last arrival releases all of them and the
 * barrier is immediately ready for the next round. All internal state is
 * protected by a single mutex.
 */
class Barrier {
public:
    /**
     * Create a Barrier that expects zero threads. Use reset() to set the
     * real number of participants before threads start waiting on it.
     */
    Barrier()
        : n_threads(0) {
    }

    /**
     * Create a Barrier.
     *  @param n_threads_ Total number of threads to wait for.
     */
    Barrier(size_t n_threads_)
        : n_threads(n_threads_) {
    }

    /**
     * Change the number of expected threads.
     *
     * Threads already blocked in wait() are not woken by this call; the
     * new target is evaluated the next time a thread arrives.
     *
     * @param n_threads_ Total number of threads to wait for
     */
    void reset(size_t n_threads_) {
        std::lock_guard<std::mutex> lh(m);
        n_threads = n_threads_;
    }

    /**
     * Wait for n_threads to invoke this function.
     *
     * If the calling thread is not the last one to arrive it blocks (inside
     * this function) until the last one does. The last thread to arrive
     * starts the next round, runs the callback and wakes everybody up.
     *
     * If the arrival count already meets or exceeds n_threads (e.g. the
     * barrier expects zero threads, or reset() lowered the target while
     * threads were waiting) the caller is treated as the last arrival
     * rather than blocking forever.
     *
     * @param cb Callback to invoke under mutual exclusion once the total
     *           number of threads has been reached. If it throws, the
     *           waiting threads are still released and the exception
     *           propagates to the caller that completed the barrier.
     */
    template <typename Callback = void(void)>
    void wait(Callback cb = do_nothing) {
        std::unique_lock<std::mutex> lh(m);

        // Remember which round we arrived in. Comparing for inequality
        // (rather than ordering) keeps the check valid if `go` ever wraps.
        const size_t generation = go;

        if (++thread_count < n_threads) {
            cv.wait(lh, [this, generation]() { return go != generation; });
            return;
        }

        // Last thread up: open the next round.
        ++go;
        thread_count = 0;

        // Guarantee the waiters are notified even if the callback throws.
        // The guard is destroyed before `lh`, so notification still happens
        // while the mutex is held, after the callback has finished.
        struct NotifyOnExit {
            std::condition_variable& target;
            ~NotifyOnExit() {
                target.notify_all(); // all threads accounted for, begin
            }
        } release_waiters{cv};

        cb();
    }

private:
    /// Default callback for wait(): intentionally does nothing.
    static void do_nothing() {}

    /// Number of threads that must call wait() to complete a round.
    size_t n_threads;
    /// Number of threads that have arrived in the current round.
    size_t thread_count {0};
    /// Round counter; incremented each time the barrier is completed.
    size_t go {0};
    /// Guards n_threads, thread_count and go.
    std::mutex m;
    /// Signalled when a round completes.
    std::condition_variable cv;
};
76