mirror of
https://github.com/OpenVPN/openvpn3.git
synced 2024-09-20 20:13:05 +02:00
Refactored RunContext to eliminate possibility of race between
set_thread and set_server.
This commit is contained in:
parent
709486cd1a
commit
46498b6bbf
@ -19,6 +19,15 @@
|
||||
// along with this program in the COPYING file.
|
||||
// If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
// Manage a pool of threads for a multi-threaded server.
|
||||
//
|
||||
// To stress test this code, in client after serv->start() add:
|
||||
// if (unit == 3 || unit == 5)
|
||||
// throw Exception("HIT IT");
|
||||
// And after "case PThreadBarrier::ERROR:"
|
||||
// if (unit & 1)
|
||||
// break;
|
||||
|
||||
#ifndef OPENVPN_COMMON_RUNCONTEXT_H
|
||||
#define OPENVPN_COMMON_RUNCONTEXT_H
|
||||
|
||||
@ -46,29 +55,6 @@ namespace openvpn {
|
||||
template <typename ServerThread, typename Stats>
|
||||
class RunContext : public LogBase
|
||||
{
|
||||
struct Thread
|
||||
{
|
||||
Thread() : thread(NULL) {}
|
||||
|
||||
Thread(Thread&& ref) noexcept
|
||||
: thread(ref.thread),
|
||||
serv(std::move(ref.serv))
|
||||
{
|
||||
static_assert(std::is_nothrow_move_constructible<Thread>::value, "class RunContext::Thread not noexcept move constructable");
|
||||
ref.thread = NULL;
|
||||
}
|
||||
|
||||
Thread(ThreadType* thread_arg) : thread(thread_arg) {}
|
||||
|
||||
~Thread() { delete thread; }
|
||||
|
||||
Thread(const Thread&) = delete;
|
||||
Thread& operator=(const Thread&) = delete;
|
||||
|
||||
ThreadType* thread;
|
||||
typename ServerThread::Ptr serv;
|
||||
};
|
||||
|
||||
public:
|
||||
typedef boost::intrusive_ptr<RunContext> Ptr;
|
||||
|
||||
@ -93,11 +79,10 @@ namespace openvpn {
|
||||
RunContext()
|
||||
: io_service(1),
|
||||
exit_timer(io_service),
|
||||
threads_added(0),
|
||||
threads_removed(0),
|
||||
thread_count(0),
|
||||
halt(false),
|
||||
log_context(this),
|
||||
log_wrap(),
|
||||
halt(false)
|
||||
log_wrap()
|
||||
{
|
||||
signals.reset(new ASIOSignals(io_service));
|
||||
signal_rearm();
|
||||
@ -110,37 +95,52 @@ namespace openvpn {
|
||||
|
||||
void set_thread(const unsigned int unit, ThreadType* thread)
|
||||
{
|
||||
if (unit != threads.size())
|
||||
throw Exception("RunContext::set_thread: unexpected unit number");
|
||||
threads.emplace_back(thread);
|
||||
while (threadlist.size() <= unit)
|
||||
threadlist.push_back(NULL);
|
||||
if (threadlist[unit])
|
||||
throw Exception("RunContext::set_thread: overwrite");
|
||||
threadlist[unit] = thread;
|
||||
}
|
||||
|
||||
// called from worker thread
|
||||
void set_server(const unsigned int unit, const typename ServerThread::Ptr& serv)
|
||||
void set_server(const unsigned int unit, ServerThread* serv)
|
||||
{
|
||||
Mutex::scoped_lock lock(mutex);
|
||||
threads[unit].serv = serv;
|
||||
if (halt)
|
||||
throw Exception("RunContext::set_server: halting");
|
||||
while (servlist.size() <= unit)
|
||||
servlist.push_back(NULL);
|
||||
if (servlist[unit])
|
||||
throw Exception("RunContext::set_server: overwrite");
|
||||
servlist[unit] = serv;
|
||||
}
|
||||
|
||||
// called from worker thread
|
||||
void clear_server(const unsigned int unit)
|
||||
{
|
||||
Mutex::scoped_lock lock(mutex);
|
||||
threads[unit].serv.reset();
|
||||
if (unit < servlist.size())
|
||||
servlist[unit] = NULL;
|
||||
}
|
||||
|
||||
void run()
|
||||
{
|
||||
if (!halt)
|
||||
{
|
||||
io_service.run();
|
||||
}
|
||||
io_service.run();
|
||||
}
|
||||
|
||||
void join()
|
||||
{
|
||||
for (size_t i = 0; i < threads.size(); ++i)
|
||||
threads[i].thread->join();
|
||||
for (size_t i = 0; i < threadlist.size(); ++i)
|
||||
{
|
||||
ThreadType* t = threadlist[i];
|
||||
if (t)
|
||||
{
|
||||
t->join();
|
||||
delete t;
|
||||
threadlist[i] = NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual void log(const std::string& str)
|
||||
@ -161,25 +161,18 @@ namespace openvpn {
|
||||
}
|
||||
|
||||
private:
|
||||
// called from worker thread
|
||||
// called from main or worker thread
|
||||
void add_thread()
|
||||
{
|
||||
Mutex::scoped_lock lock(mutex);
|
||||
++threads_added;
|
||||
++thread_count;
|
||||
}
|
||||
|
||||
// called from worker thread
|
||||
// called from main or worker thread
|
||||
void remove_thread()
|
||||
{
|
||||
Mutex::scoped_lock lock(mutex);
|
||||
++threads_removed;
|
||||
test_completion();
|
||||
}
|
||||
|
||||
void test_completion()
|
||||
{
|
||||
const size_t s = threads.size();
|
||||
if (threads_added == s && threads_removed == s)
|
||||
if (--thread_count <= 0)
|
||||
do_cancel();
|
||||
}
|
||||
|
||||
@ -189,7 +182,7 @@ namespace openvpn {
|
||||
do_cancel();
|
||||
}
|
||||
|
||||
// may be called from worker thread
|
||||
// called from main or worker thread
|
||||
void do_cancel()
|
||||
{
|
||||
if (!halt)
|
||||
@ -202,26 +195,24 @@ namespace openvpn {
|
||||
io_service.post(asio_dispatch_post(&ASIOSignals::cancel, signals.get()));
|
||||
|
||||
unsigned int stopped = 0;
|
||||
for (size_t i = 0; i < threads.size(); ++i)
|
||||
for (size_t i = 0; i < servlist.size(); ++i)
|
||||
{
|
||||
Thread& thr = threads[i];
|
||||
if (thr.serv)
|
||||
ServerThread* serv = servlist[i];
|
||||
if (serv)
|
||||
{
|
||||
thr.serv->thread_safe_stop();
|
||||
serv->thread_safe_stop();
|
||||
++stopped;
|
||||
}
|
||||
thr.serv.reset();
|
||||
servlist[i] = NULL;
|
||||
}
|
||||
OPENVPN_LOG("Stopping " << stopped << '/' << threads.size() << " thread(s)");
|
||||
OPENVPN_LOG("Stopping " << stopped << '/' << servlist.size() << " thread(s)");
|
||||
}
|
||||
}
|
||||
|
||||
void exit_timer_callback(const boost::system::error_code& e)
|
||||
{
|
||||
if (!e && !halt)
|
||||
{
|
||||
cancel();
|
||||
}
|
||||
cancel();
|
||||
}
|
||||
|
||||
void signal(const boost::system::error_code& error, int signum)
|
||||
@ -250,18 +241,23 @@ namespace openvpn {
|
||||
signals->register_signals_all(asio_dispatch_signal(&RunContext::signal, this));
|
||||
}
|
||||
|
||||
// these vars only used by main thread
|
||||
boost::asio::io_service io_service;
|
||||
typename Stats::Ptr stats;
|
||||
ASIOSignals::Ptr signals;
|
||||
AsioTimer exit_timer;
|
||||
std::vector<Thread> threads;
|
||||
unsigned int threads_added;
|
||||
unsigned int threads_removed;
|
||||
std::vector<ThreadType*> threadlist;
|
||||
|
||||
// servlist and related vars protected by mutex
|
||||
Mutex mutex;
|
||||
std::vector<ServerThread*> servlist;
|
||||
int thread_count;
|
||||
volatile bool halt;
|
||||
|
||||
// logging protected by log_mutex
|
||||
Mutex log_mutex;
|
||||
Log::Context log_context;
|
||||
Log::Context::Wrapper log_wrap; // must be constructed after log_context
|
||||
Mutex mutex;
|
||||
Mutex log_mutex;
|
||||
bool halt;
|
||||
};
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user