LCOV - code coverage report
Current view: top level - src - ctpl_stl.h (source / functions) Hit Total Coverage
Test: total_coverage.info Lines: 72 96 75.0 %
Date: 2025-02-23 09:33:43 Functions: 25 31 80.6 %

          Line data    Source code
       1             : /*********************************************************
       2             : *
       3             : *  Copyright (C) 2014 by Vitaliy Vitsentiy
       4             : *
       5             : *  Licensed under the Apache License, Version 2.0 (the "License");
       6             : *  you may not use this file except in compliance with the License.
       7             : *  You may obtain a copy of the License at
       8             : *
       9             : *     http://www.apache.org/licenses/LICENSE-2.0
      10             : *
      11             : *  Unless required by applicable law or agreed to in writing, software
      12             : *  distributed under the License is distributed on an "AS IS" BASIS,
      13             : *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      14             : *  See the License for the specific language governing permissions and
      15             : *  limitations under the License.
      16             : *
      17             : *********************************************************/
      18             : 
      19             : 
      20             : #ifndef __ctpl_stl_thread_pool_H__
      21             : #define __ctpl_stl_thread_pool_H__
      22             : 
      23             : #include <functional>
      24             : #include <thread>
      25             : #include <atomic>
      26             : #include <vector>
      27             : #include <memory>
      28             : #include <exception>
      29             : #include <future>
      30             : #include <mutex>
      31             : #include <queue>
      32             : 
      33             : 
      34             : 
      35             : // thread pool to run user's functors with signature
      36             : //      ret func(int id, other_params)
      37             : // where id is the index of the thread that runs the functor
      38             : // ret is some return type
      39             : 
      40             : 
      41             : namespace ctpl {
      42             : 
      43             :     namespace detail {
      44             :         template <typename T>
      45           0 :         class Queue {
      46             :         public:
      47        3162 :             bool push(T const & value) {
      48        3162 :                 std::unique_lock<std::mutex> lock(this->mutex);
      49        3162 :                 this->q.push(value);
      50        6324 :                 return true;
      51             :             }
      52             :             // deletes the retrieved element, do not use for non integral types
      53       12299 :             bool pop(T & v) {
      54       12299 :                 std::unique_lock<std::mutex> lock(this->mutex);
      55       12299 :                 if (this->q.empty())
      56             :                     return false;
      57        3162 :                 v = this->q.front();
      58       15461 :                 this->q.pop();
      59             :                 return true;
      60             :             }
      61             :             bool empty() {
      62             :                 std::unique_lock<std::mutex> lock(this->mutex);
      63             :                 return this->q.empty();
      64             :             }
      65             :         private:
      66             :             std::queue<T> q;
      67             :             std::mutex mutex;
      68             :         };
      69             :     }
      70             : 
      71             :     class thread_pool {
      72             : 
      73             :     public:
      74             : 
      75         476 :         thread_pool() { this->init(); }
      76             :         explicit thread_pool(int nThreads) { this->init(); this->resize(nThreads); }
      77             : 
      78             :         // the destructor waits for all the functions in the queue to be finished
      79         476 :         ~thread_pool() {
      80         476 :             this->stop(true);
      81         476 :         }
      82             : 
      83             :         // get the number of running threads in the pool
      84        2454 :         int size() { return static_cast<int>(this->threads.size()); }
      85             : 
      86             :         // number of idle threads
      87             :         int n_idle() { return this->nWaiting; }
      88             :         std::thread & get_thread(int i) { return *this->threads[i]; }
      89             : 
      90             :         // change the number of threads in the pool
      91             :         // should be called from one thread, otherwise be careful to not interleave, also with this->stop()
      92             :         // nThreads must be >= 0
      93         348 :         void resize(int nThreads) {
      94         348 :             if (!this->isStop && !this->isDone) {
      95         348 :                 int oldNThreads = static_cast<int>(this->threads.size());
      96         348 :                 if (oldNThreads <= nThreads) {  // if the number of threads is increased
      97         348 :                     this->threads.resize(nThreads);
      98         348 :                     this->flags.resize(nThreads);
      99             : 
     100        1740 :                     for (int i = oldNThreads; i < nThreads; ++i) {
     101        1392 :                         this->flags[i] = std::make_shared<std::atomic<bool>>(false);
     102        1392 :                         this->set_thread(i);
     103             :                     }
     104             :                 }
     105             :                 else {  // the number of threads is decreased
     106           0 :                     for (int i = oldNThreads - 1; i >= nThreads; --i) {
     107           0 :                         *this->flags[i] = true;  // this thread will finish
     108           0 :                         this->threads[i]->detach();
     109             :                     }
     110           0 :                     {
     111             :                         // stop the detached threads that were waiting
     112           0 :                         std::unique_lock<std::mutex> lock(this->mutex);
     113           0 :                         this->cv.notify_all();
     114             :                     }
     115           0 :                     this->threads.resize(nThreads);  // safe to delete because the threads are detached
     116           0 :                     this->flags.resize(nThreads);  // safe to delete because the threads have copies of shared_ptr of the flags, not originals
     117             :                 }
     118             :             }
     119         348 :         }
     120             : 
     121             :         // empty the queue
     122        1310 :         void clear_queue() {
     123        1310 :             std::function<void(int id)> * _f;
     124        1310 :             while (this->q.pop(_f))
     125           0 :                 delete _f; // empty the queue
     126        1310 :         }
     127             : 
     128             :         // pops a functional wrapper to the original function
     129             :         std::function<void(int)> pop() {
     130             :             std::function<void(int id)> * _f = nullptr;
     131             :             this->q.pop(_f);
     132             :             std::unique_ptr<std::function<void(int id)>> func(_f); // at return, delete the function even if an exception occurred
     133             :             std::function<void(int)> f;
     134             :             if (_f)
     135             :                 f = *_f;
     136             :             return f;
     137             :         }
     138             : 
     139             :         // wait for all computing threads to finish and stop all threads
     140             :         // may be called asynchronously to not pause the calling thread while waiting
     141             :         // if isWait == true, all the functions in the queue are run, otherwise the queue is cleared without running the functions
     142        1310 :         void stop(bool isWait = false) {
     143        1310 :             if (!isWait) {
     144           0 :                 if (this->isStop)
     145             :                     return;
     146           0 :                 this->isStop = true;
     147           0 :                 for (int i = 0, n = this->size(); i < n; ++i) {
     148           0 :                     *this->flags[i] = true;  // command the threads to stop
     149             :                 }
     150           0 :                 this->clear_queue();  // empty the queue
     151             :             }
     152             :             else {
     153        1310 :                 if (this->isDone || this->isStop)
     154         834 :                     return;
     155         476 :                 this->isDone = true;  // give the waiting threads a command to finish
     156             :             }
     157         476 :             {
     158         476 :                 std::unique_lock<std::mutex> lock(this->mutex);
     159         476 :                 this->cv.notify_all();  // stop all waiting threads
     160             :             }
     161        1868 :             for (int i = 0; i < static_cast<int>(this->threads.size()); ++i) {  // wait for the computing threads to finish
     162        1392 :                 if (this->threads[i]->joinable())
     163        1392 :                     this->threads[i]->join();
     164             :             }
     165             :             // if there were no threads in the pool but some functors in the queue, the functors are not deleted by the threads
     166             :             // therefore delete them here
     167         476 :             this->clear_queue();
     168         476 :             this->threads.clear();
     169         476 :             this->flags.clear();
     170             :         }
     171             : 
     172             :         template<typename F, typename... Rest>
     173           0 :         auto push(F && f, Rest&&... rest) ->std::future<decltype(f(0, rest...))> {
     174           0 :             auto pck = std::make_shared<std::packaged_task<decltype(f(0, rest...))(int)>>(
     175             :                     std::bind(std::forward<F>(f), std::placeholders::_1, std::forward<Rest>(rest)...)
     176             :             );
     177           0 :             auto _f = new std::function<void(int id)>([pck](int id) {
     178           0 :                 (*pck)(id);
     179             :             });
     180           0 :             this->q.push(_f);
     181           0 :             std::unique_lock<std::mutex> lock(this->mutex);
     182           0 :             this->cv.notify_one();
     183           0 :             return pck->get_future();
     184             :         }
     185             : 
     186             :         // run the user's function that excepts argument int - id of the running thread. returned value is templatized
     187             :         // operator returns std::future, where the user can get the result and rethrow the caught exceptions
     188             :         template<typename F>
     189        2832 :         auto push(F && f) ->std::future<decltype(f(0))> {
     190        2832 :             auto pck = std::make_shared<std::packaged_task<decltype(f(0))(int)>>(std::forward<F>(f));
     191       14160 :             auto _f = new std::function<void(int id)>([pck](int id) {
     192        2832 :                 (*pck)(id);
     193             :             });
     194        2832 :             this->q.push(_f);
     195        5664 :             std::unique_lock<std::mutex> lock(this->mutex);
     196        2832 :             this->cv.notify_one();
     197        5664 :             return pck->get_future();
     198             :         }
     199             : 
     200             : 
     201             :     private:
     202             : 
     203             :         // deleted
     204             :         thread_pool(const thread_pool &);// = delete;
     205             :         thread_pool(thread_pool &&);// = delete;
     206             :         thread_pool & operator=(const thread_pool &);// = delete;
     207             :         thread_pool & operator=(thread_pool &&);// = delete;
     208             : 
     209        1392 :         void set_thread(int i) {
     210        1392 :             std::shared_ptr<std::atomic<bool>> flag(this->flags[i]); // a copy of the shared ptr to the flag
     211        2784 :             auto f = [this, i, flag/* a copy of the shared ptr to the flag */]() {
     212        1392 :                 std::atomic<bool> & _flag = *flag;
     213        1392 :                 std::function<void(int id)> * _f;
     214        1392 :                 bool isPop = this->q.pop(_f);
     215        5927 :                 while (true) {
     216        5927 :                     while (isPop) {  // if there is anything in the queue
     217        5664 :                         std::unique_ptr<std::function<void(int id)>> func(_f); // at return, delete the function even if an exception occurred
     218        2832 :                         (*_f)(i);
     219        2832 :                         if (_flag)
     220           0 :                             return;  // the thread is wanted to stop, return even if the queue is not empty yet
     221             :                         else
     222        2832 :                             isPop = this->q.pop(_f);
     223             :                     }
     224             :                     // the queue is empty here, wait for the next command
     225        4798 :                     std::unique_lock<std::mutex> lock(this->mutex);
     226        3095 :                     ++this->nWaiting;
     227        9530 :                     this->cv.wait(lock, [this, &_f, &isPop, &_flag](){ isPop = this->q.pop(_f); return isPop || this->isDone || _flag; });
     228        3095 :                     --this->nWaiting;
     229        3095 :                     if (!isPop)
     230        2784 :                         return;  // if the queue is empty and this->isDone == true or *flag then return
     231             :                 }
     232        2784 :             };
     233        2784 :             this->threads[i].reset(new std::thread(f)); // compiler may not support std::make_unique()
     234        1392 :         }
     235             : 
     236         476 :         void init() { this->nWaiting = 0; this->isStop = false; this->isDone = false; }
     237             : 
     238             :         std::vector<std::unique_ptr<std::thread>> threads;
     239             :         std::vector<std::shared_ptr<std::atomic<bool>>> flags;
     240             :         detail::Queue<std::function<void(int id)> *> q;
     241             :         std::atomic<bool> isDone;
     242             :         std::atomic<bool> isStop;
     243             :         std::atomic<int> nWaiting;  // how many threads are waiting
     244             : 
     245             :         std::mutex mutex;
     246             :         std::condition_variable cv;
     247             :     };
     248             : 
     249             : }
     250             : 
     251             : #endif // __ctpl_stl_thread_pool_H__

Generated by: LCOV version 1.14