From 46498b6bbf6fee1a1e5e852c8e64f949f7b6cd6b Mon Sep 17 00:00:00 2001 From: James Yonan Date: Sat, 25 Apr 2015 10:03:33 -0600 Subject: [PATCH] Refactored RunContext to eliminate possibility of race between set_thread and set_server. --- openvpn/common/runcontext.hpp | 126 ++++++++++++++++------------------ 1 file changed, 61 insertions(+), 65 deletions(-) diff --git a/openvpn/common/runcontext.hpp b/openvpn/common/runcontext.hpp index 653c3e10..59aa0a9c 100644 --- a/openvpn/common/runcontext.hpp +++ b/openvpn/common/runcontext.hpp @@ -19,6 +19,15 @@ // along with this program in the COPYING file. // If not, see . +// Manage a pool of threads for a multi-threaded server. +// +// To stress test this code, in client after serv->start() add: +// if (unit == 3 || unit == 5) +// throw Exception("HIT IT"); +// And after "case PThreadBarrier::ERROR:" +// if (unit & 1) +// break; + #ifndef OPENVPN_COMMON_RUNCONTEXT_H #define OPENVPN_COMMON_RUNCONTEXT_H @@ -46,29 +55,6 @@ namespace openvpn { template class RunContext : public LogBase { - struct Thread - { - Thread() : thread(NULL) {} - - Thread(Thread&& ref) noexcept - : thread(ref.thread), - serv(std::move(ref.serv)) - { - static_assert(std::is_nothrow_move_constructible::value, "class RunContext::Thread not noexcept move constructable"); - ref.thread = NULL; - } - - Thread(ThreadType* thread_arg) : thread(thread_arg) {} - - ~Thread() { delete thread; } - - Thread(const Thread&) = delete; - Thread& operator=(const Thread&) = delete; - - ThreadType* thread; - typename ServerThread::Ptr serv; - }; - public: typedef boost::intrusive_ptr Ptr; @@ -93,11 +79,10 @@ namespace openvpn { RunContext() : io_service(1), exit_timer(io_service), - threads_added(0), - threads_removed(0), + thread_count(0), + halt(false), log_context(this), - log_wrap(), - halt(false) + log_wrap() { signals.reset(new ASIOSignals(io_service)); signal_rearm(); @@ -110,37 +95,52 @@ namespace openvpn { void set_thread(const unsigned int unit, ThreadType* thread) { - if (unit != threads.size()) - throw Exception("RunContext::set_thread: unexpected unit number"); - threads.emplace_back(thread); + while (threadlist.size() <= unit) + threadlist.push_back(NULL); + if (threadlist[unit]) + throw Exception("RunContext::set_thread: overwrite"); + threadlist[unit] = thread; } // called from worker thread - void set_server(const unsigned int unit, const typename ServerThread::Ptr& serv) + void set_server(const unsigned int unit, ServerThread* serv) { Mutex::scoped_lock lock(mutex); - threads[unit].serv = serv; + if (halt) + throw Exception("RunContext::set_server: halting"); + while (servlist.size() <= unit) + servlist.push_back(NULL); + if (servlist[unit]) + throw Exception("RunContext::set_server: overwrite"); + servlist[unit] = serv; } // called from worker thread void clear_server(const unsigned int unit) { Mutex::scoped_lock lock(mutex); - threads[unit].serv.reset(); + if (unit < servlist.size()) + servlist[unit] = NULL; } void run() { if (!halt) - { - io_service.run(); - } + io_service.run(); } void join() { - for (size_t i = 0; i < threads.size(); ++i) - threads[i].thread->join(); + for (size_t i = 0; i < threadlist.size(); ++i) + { + ThreadType* t = threadlist[i]; + if (t) + { + t->join(); + delete t; + threadlist[i] = NULL; + } + } } virtual void log(const std::string& str) @@ -161,25 +161,18 @@ namespace openvpn { } private: - // called from worker thread + // called from main or worker thread void add_thread() { Mutex::scoped_lock lock(mutex); - ++threads_added; + ++thread_count; } - // called from worker thread + // called from main or worker thread void remove_thread() { Mutex::scoped_lock lock(mutex); - ++threads_removed; - test_completion(); - } - - void test_completion() - { - const size_t s = threads.size(); - if (threads_added == s && threads_removed == s) + if (--thread_count <= 0) do_cancel(); } @@ -189,7 +182,7 @@ namespace openvpn { do_cancel(); } - // may be called from worker thread + // called from main or worker thread void do_cancel() { if (!halt) @@ -202,26 +195,24 @@ namespace openvpn { io_service.post(asio_dispatch_post(&ASIOSignals::cancel, signals.get())); unsigned int stopped = 0; - for (size_t i = 0; i < threads.size(); ++i) + for (size_t i = 0; i < servlist.size(); ++i) { - Thread& thr = threads[i]; - if (thr.serv) + ServerThread* serv = servlist[i]; + if (serv) { - thr.serv->thread_safe_stop(); + serv->thread_safe_stop(); ++stopped; } - thr.serv.reset(); + servlist[i] = NULL; } - OPENVPN_LOG("Stopping " << stopped << '/' << threads.size() << " thread(s)"); + OPENVPN_LOG("Stopping " << stopped << '/' << servlist.size() << " thread(s)"); } } void exit_timer_callback(const boost::system::error_code& e) { if (!e && !halt) - { - cancel(); - } + cancel(); } void signal(const boost::system::error_code& error, int signum) @@ -250,18 +241,23 @@ namespace openvpn { signals->register_signals_all(asio_dispatch_signal(&RunContext::signal, this)); } + // these vars only used by main thread boost::asio::io_service io_service; typename Stats::Ptr stats; ASIOSignals::Ptr signals; AsioTimer exit_timer; - std::vector threads; - unsigned int threads_added; - unsigned int threads_removed; + std::vector threadlist; + + // servlist and related vars protected by mutex + Mutex mutex; + std::vector servlist; + int thread_count; + volatile bool halt; + + // logging protected by log_mutex + Mutex log_mutex; Log::Context log_context; Log::Context::Wrapper log_wrap; // must be constructed after log_context - Mutex mutex; - Mutex log_mutex; - bool halt; }; }