0
0
mirror of https://github.com/OpenVPN/openvpn3.git synced 2024-09-20 04:02:15 +02:00

Refactor ClientProto::Session to use ProtoContext as field insatead of Base

Currently the protocontext is used as kind of composition but not really
and makes following the code harder, since this inheritance not only serves
for composition but also as callbacks through virtual method inheritance.

Making ProtoContext a normal field and definining a callback interface makes
the class relationship easier to understand.

Signed-off-by: Arne Schwabe <arne@openvpn.net>
This commit is contained in:
Arne Schwabe 2024-02-22 07:26:46 +01:00 committed by Jenkins-dev
parent 2ff8029eba
commit e14c3f0441
5 changed files with 282 additions and 249 deletions

View File

@ -1164,7 +1164,7 @@ class ClientOptions : public RC<thread_unsafe_refcount>
// Copy ProtoConfig so that modifications due to server push will // Copy ProtoConfig so that modifications due to server push will
// not persist across client instantiations. // not persist across client instantiations.
cli_config->proto_context_config.reset(new Client::ProtoConfig(proto_config_cached(relay_mode))); cli_config->proto_context_config.reset(new ProtoContext::ProtoConfig(proto_config_cached(relay_mode)));
cli_config->proto_context_options = proto_context_options; cli_config->proto_context_options = proto_context_options;
cli_config->push_base = push_base; cli_config->push_base = push_base;
@ -1274,7 +1274,7 @@ class ClientOptions : public RC<thread_unsafe_refcount>
} }
private: private:
Client::ProtoConfig &proto_config_cached(const bool relay_mode) ProtoContext::ProtoConfig &proto_config_cached(const bool relay_mode)
{ {
if (relay_mode && cp_relay) if (relay_mode && cp_relay)
return *cp_relay; return *cp_relay;
@ -1282,14 +1282,14 @@ class ClientOptions : public RC<thread_unsafe_refcount>
return *cp_main; return *cp_main;
} }
Client::ProtoConfig::Ptr proto_config(const OptionList &opt, ProtoContext::ProtoConfig::Ptr proto_config(const OptionList &opt,
const Config &config, const Config &config,
const ParseClientConfig &pcc, const ParseClientConfig &pcc,
const bool relay_mode) const bool relay_mode)
{ {
// relay mode is null unless one of the below directives is defined // relay mode is null unless one of the below directives is defined
if (relay_mode && !opt.exists("relay-mode")) if (relay_mode && !opt.exists("relay-mode"))
return Client::ProtoConfig::Ptr(); return ProtoContext::ProtoConfig::Ptr();
// load flags // load flags
unsigned int lflags = SSLConfigAPI::LF_PARSE_MODE; unsigned int lflags = SSLConfigAPI::LF_PARSE_MODE;
@ -1314,7 +1314,7 @@ class ClientOptions : public RC<thread_unsafe_refcount>
cc->set_tls_ciphersuite_list(config.clientconf.tlsCiphersuitesList); cc->set_tls_ciphersuite_list(config.clientconf.tlsCiphersuitesList);
// client ProtoContext config // client ProtoContext config
Client::ProtoConfig::Ptr cp(new Client::ProtoConfig()); ProtoContext::ProtoConfig::Ptr cp(new ProtoContext::ProtoConfig());
cp->ssl_factory = cc->new_factory(); cp->ssl_factory = cc->new_factory();
cp->relay_mode = relay_mode; cp->relay_mode = relay_mode;
cp->dc.set_factory(new CryptoDCSelect<SSLLib::CryptoAPI>(cp->ssl_factory->libctx(), frame, cli_stats, rng)); cp->dc.set_factory(new CryptoDCSelect<SSLLib::CryptoAPI>(cp->ssl_factory->libctx(), frame, cli_stats, rng));
@ -1454,8 +1454,8 @@ class ClientOptions : public RC<thread_unsafe_refcount>
RandomAPI::Ptr prng; RandomAPI::Ptr prng;
Frame::Ptr frame; Frame::Ptr frame;
Layer layer; Layer layer;
Client::ProtoConfig::Ptr cp_main; ProtoContext::ProtoConfig::Ptr cp_main;
Client::ProtoConfig::Ptr cp_relay; ProtoContext::ProtoConfig::Ptr cp_relay;
RemoteList::Ptr remote_list; RemoteList::Ptr remote_list;
bool server_addr_float; bool server_addr_float;
TransportClientFactory::Ptr transport_factory; TransportClientFactory::Ptr transport_factory;

View File

@ -92,20 +92,13 @@ struct NotifyCallback
} }
}; };
class Session : ProtoContext, class Session : ProtoContextCallbackInterface,
TransportClientParent, TransportClientParent,
TunClientParent, TunClientParent,
public RC<thread_unsafe_refcount> public RC<thread_unsafe_refcount>
{ {
typedef ProtoContext Base;
typedef Base::PacketType PacketType;
using Base::now;
using Base::stat;
public: public:
typedef RCPtr<Session> Ptr; typedef RCPtr<Session> Ptr;
typedef Base::ProtoConfig ProtoConfig;
OPENVPN_EXCEPTION(client_exception); OPENVPN_EXCEPTION(client_exception);
OPENVPN_EXCEPTION(client_halt_restart); OPENVPN_EXCEPTION(client_halt_restart);
@ -133,7 +126,7 @@ class Session : ProtoContext,
{ {
} }
ProtoConfig::Ptr proto_context_config; ProtoContext::ProtoConfig::Ptr proto_context_config;
ProtoContextCompressionOptions::Ptr proto_context_options; ProtoContextCompressionOptions::Ptr proto_context_options;
PushOptionsBase::Ptr push_base; PushOptionsBase::Ptr push_base;
TransportClientFactory::Ptr transport_factory; TransportClientFactory::Ptr transport_factory;
@ -152,7 +145,7 @@ class Session : ProtoContext,
Session(openvpn_io::io_context &io_context_arg, Session(openvpn_io::io_context &io_context_arg,
const Config &config, const Config &config,
NotifyCallback *notify_callback_arg) NotifyCallback *notify_callback_arg)
: Base(config.proto_context_config, config.cli_stats), : proto_context(this, config.proto_context_config, config.cli_stats),
io_context(io_context_arg), io_context(io_context_arg),
transport_factory(config.transport_factory), transport_factory(config.transport_factory),
tun_factory(config.tun_factory), tun_factory(config.tun_factory),
@ -178,9 +171,9 @@ class Session : ProtoContext,
if (!packet_log) if (!packet_log)
OPENVPN_THROW(open_file_error, "cannot open packet log for output: " << OPENVPN_PACKET_LOG); OPENVPN_THROW(open_file_error, "cannot open packet log for output: " << OPENVPN_PACKET_LOG);
#endif #endif
Base::update_now(); proto_context.update_now();
Base::reset(); proto_context.reset();
// Base::enable_strict_openvpn_2x(); // proto_context.enable_strict_openvpn_2x();
info_hold.reset(new std::vector<ClientEvent::Base::Ptr>()); info_hold.reset(new std::vector<ClientEvent::Base::Ptr>());
} }
@ -194,7 +187,7 @@ class Session : ProtoContext,
{ {
if (!halt) if (!halt)
{ {
Base::update_now(); proto_context.update_now();
// coarse wakeup range // coarse wakeup range
housekeeping_schedule.init(Time::Duration::binary_ms(512), Time::Duration::binary_ms(1024)); housekeeping_schedule.init(Time::Duration::binary_ms(512), Time::Duration::binary_ms(1024));
@ -224,7 +217,7 @@ class Session : ProtoContext,
void send_explicit_exit_notify() void send_explicit_exit_notify()
{ {
if (!halt) if (!halt)
Base::send_explicit_exit_notify(); proto_context.send_explicit_exit_notify();
} }
void tun_set_disconnect() void tun_set_disconnect()
@ -235,22 +228,22 @@ class Session : ProtoContext,
void post_cc_msg(const std::string &msg) void post_cc_msg(const std::string &msg)
{ {
Base::update_now(); proto_context.update_now();
Base::write_control_string(msg); proto_context.write_control_string(msg);
Base::flush(true); proto_context.flush(true);
set_housekeeping_timer(); set_housekeeping_timer();
} }
void post_app_control_message(const std::string proto, const std::string message) void post_app_control_message(const std::string proto, const std::string message)
{ {
if (!conf().app_control_config.supports_protocol(proto)) if (!proto_context.conf().app_control_config.supports_protocol(proto))
{ {
ClientEvent::Base::Ptr ev = new ClientEvent::UnsupportedFeature{"missing acc protocol support", "server has not announced support of this custom app control protocol", false}; ClientEvent::Base::Ptr ev = new ClientEvent::UnsupportedFeature{"missing acc protocol support", "server has not announced support of this custom app control protocol", false};
cli_events->add_event(std::move(ev)); cli_events->add_event(std::move(ev));
return; return;
} }
for (auto fragment : conf().app_control_config.format_message(proto, message)) for (auto fragment : proto_context.conf().app_control_config.format_message(proto, message))
post_cc_msg(std::move(fragment)); post_cc_msg(std::move(fragment));
} }
@ -328,13 +321,13 @@ class Session : ProtoContext,
{ {
try try
{ {
OPENVPN_LOG_CLIPROTO("Transport RECV " << server_endpoint_render() << ' ' << Base::dump_packet(buf)); OPENVPN_LOG_CLIPROTO("Transport RECV " << server_endpoint_render() << ' ' << proto_context.dump_packet(buf));
// update current time // update current time
Base::update_now(); proto_context.update_now();
// update last packet received // update last packet received
stat().update_last_packet_received(now()); proto_context.stat().update_last_packet_received(proto_context.now());
// log connecting event (only on first packet received) // log connecting event (only on first packet received)
if (!first_packet_received_) if (!first_packet_received_)
@ -345,13 +338,13 @@ class Session : ProtoContext,
} }
// get packet type // get packet type
Base::PacketType pt = Base::packet_type(buf); ProtoContext::PacketType pt = proto_context.packet_type(buf);
// process packet // process packet
if (pt.is_data()) if (pt.is_data())
{ {
// data packet // data packet
Base::data_decrypt(pt, buf); proto_context.data_decrypt(pt, buf);
if (buf.size()) if (buf.size())
{ {
#ifdef OPENVPN_PACKET_LOG #ifdef OPENVPN_PACKET_LOG
@ -366,15 +359,15 @@ class Session : ProtoContext,
} }
// do a lightweight flush // do a lightweight flush
Base::flush(false); proto_context.flush(false);
} }
else if (pt.is_control()) else if (pt.is_control())
{ {
// control packet // control packet
Base::control_net_recv(pt, std::move(buf)); proto_context.control_net_recv(pt, std::move(buf));
// do a full flush // do a full flush
Base::flush(true); proto_context.flush(true);
} }
else else
cli_stats->error(Error::KEY_STATE_ERROR); cli_stats->error(Error::KEY_STATE_ERROR);
@ -412,7 +405,7 @@ class Session : ProtoContext,
OPENVPN_LOG_CLIPROTO("TUN recv, size=" << buf.size()); OPENVPN_LOG_CLIPROTO("TUN recv, size=" << buf.size());
// update current time // update current time
Base::update_now(); proto_context.update_now();
// log packet // log packet
#ifdef OPENVPN_PACKET_LOG #ifdef OPENVPN_PACKET_LOG
@ -432,7 +425,7 @@ class Session : ProtoContext,
// encrypt packet // encrypt packet
if (buf.size()) if (buf.size())
{ {
const ProtoContext::ProtoConfig &c = Base::conf(); const ProtoContext::ProtoConfig &c = proto_context.conf();
// when calculating mss, we take IPv4 and TCP headers into account // when calculating mss, we take IPv4 and TCP headers into account
// here we need to add it back since we check the whole IP packet size, not just TCP payload // here we need to add it back since we check the whole IP packet size, not just TCP payload
constexpr size_t MinTcpHeader = 20; constexpr size_t MinTcpHeader = 20;
@ -445,13 +438,13 @@ class Session : ProtoContext,
} }
else else
{ {
Base::data_encrypt(buf); proto_context.data_encrypt(buf);
if (buf.size()) if (buf.size())
{ {
// send packet via transport to destination // send packet via transport to destination
OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << Base::dump_packet(buf)); OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << proto_context.dump_packet(buf));
if (transport->transport_send(buf)) if (transport->transport_send(buf))
Base::update_last_sent(); proto_context.update_last_sent();
else if (halt) else if (halt)
return; return;
} }
@ -459,7 +452,7 @@ class Session : ProtoContext,
} }
// do a lightweight flush // do a lightweight flush
Base::flush(false); proto_context.flush(false);
// schedule housekeeping wakeup // schedule housekeeping wakeup
set_housekeeping_timer(); set_housekeeping_timer();
@ -473,7 +466,7 @@ class Session : ProtoContext,
// Return true if keepalive parameter(s) are enabled. // Return true if keepalive parameter(s) are enabled.
bool is_keepalive_enabled() const override bool is_keepalive_enabled() const override
{ {
return Base::is_keepalive_enabled(); return proto_context.is_keepalive_enabled();
} }
// Disable keepalive for rest of session, but fetch // Disable keepalive for rest of session, but fetch
@ -481,7 +474,7 @@ class Session : ProtoContext,
void disable_keepalive(unsigned int &keepalive_ping, void disable_keepalive(unsigned int &keepalive_ping,
unsigned int &keepalive_timeout) override unsigned int &keepalive_timeout) override
{ {
Base::disable_keepalive(keepalive_ping, keepalive_timeout); proto_context.disable_keepalive(keepalive_ping, keepalive_timeout);
} }
void transport_pre_resolve() override void transport_pre_resolve() override
@ -516,9 +509,9 @@ class Session : ProtoContext,
try try
{ {
OPENVPN_LOG("Connecting to " << server_endpoint_render()); OPENVPN_LOG("Connecting to " << server_endpoint_render());
Base::set_protocol(transport->transport_protocol()); proto_context.set_protocol(transport->transport_protocol());
Base::start(); proto_context.start();
Base::flush(true); proto_context.flush(true);
set_housekeeping_timer(); set_housekeeping_timer();
} }
catch (const std::exception &e) catch (const std::exception &e)
@ -587,7 +580,7 @@ class Session : ProtoContext,
OPENVPN_LOG("Session token: [redacted]"); OPENVPN_LOG("Session token: [redacted]");
#endif #endif
autologin_sessions = true; autologin_sessions = true;
conf().set_xmit_creds(true); proto_context.conf().set_xmit_creds(true);
creds->set_replace_password_with_session_id(true); creds->set_replace_password_with_session_id(true);
creds->set_session_id(username, sess_id); creds->set_session_id(username, sess_id);
} }
@ -688,9 +681,9 @@ class Session : ProtoContext,
// proto base class calls here for control channel network sends // proto base class calls here for control channel network sends
void control_net_send(const Buffer &net_buf) override void control_net_send(const Buffer &net_buf) override
{ {
OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << Base::dump_packet(net_buf)); OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << proto_context.dump_packet(net_buf));
if (transport->transport_send_const(net_buf)) if (transport->transport_send_const(net_buf))
Base::update_last_sent(); proto_context.update_last_sent();
} }
void recv_auth_failed(const std::string &msg) void recv_auth_failed(const std::string &msg)
@ -753,7 +746,7 @@ class Session : ProtoContext,
{ {
timeout = clamp_to_typerange<unsigned int>(std::stoul(timeout_str)); timeout = clamp_to_typerange<unsigned int>(std::stoul(timeout_str));
// Cap the timeout to end well before renegotiation starts // Cap the timeout to end well before renegotiation starts
timeout = std::min(timeout, static_cast<decltype(timeout)>(conf().renegotiate.to_seconds() / 2)); timeout = std::min(timeout, static_cast<decltype(timeout)>(proto_context.conf().renegotiate.to_seconds() / 2));
} }
catch (const std::logic_error &) catch (const std::logic_error &)
{ {
@ -776,7 +769,7 @@ class Session : ProtoContext,
void recv_relay() void recv_relay()
{ {
if (conf().relay_mode) if (proto_context.conf().relay_mode)
{ {
fatal_ = Error::RELAY; fatal_ = Error::RELAY;
fatal_reason_ = ""; fatal_reason_ = "";
@ -818,7 +811,7 @@ class Session : ProtoContext,
// proto base class calls here for app-level control-channel messages received // proto base class calls here for app-level control-channel messages received
void control_recv(BufferPtr &&app_bp) override void control_recv(BufferPtr &&app_bp) override
{ {
const std::string msg = Unicode::utf8_printable(Base::template read_control_string<std::string>(*app_bp), const std::string msg = Unicode::utf8_printable(ProtoContext::template read_control_string<std::string>(*app_bp),
Unicode::UTF8_FILTER | Unicode::UTF8_PASS_FMT); Unicode::UTF8_FILTER | Unicode::UTF8_PASS_FMT);
// OPENVPN_LOG("SERVER: " << sanitize_control_message(msg)); // OPENVPN_LOG("SERVER: " << sanitize_control_message(msg));
@ -859,13 +852,13 @@ class Session : ProtoContext,
void recv_custom_control_message(const std::string msg) void recv_custom_control_message(const std::string msg)
{ {
bool fullmessage = conf().app_control_recv.receive_message(msg); bool fullmessage = proto_context.conf().app_control_recv.receive_message(msg);
if (!fullmessage) if (!fullmessage)
return; return;
auto [proto, app_proto_msg] = conf().app_control_recv.get_message(); auto [proto, app_proto_msg] = proto_context.conf().app_control_recv.get_message();
if (conf().app_control_config.supports_protocol(proto)) if (proto_context.conf().app_control_config.supports_protocol(proto))
{ {
auto ev = new ClientEvent::AppCustomControlMessage(std::move(proto), std::move(app_proto_msg)); auto ev = new ClientEvent::AppCustomControlMessage(std::move(proto), std::move(app_proto_msg));
cli_events->add_event(std::move(ev)); cli_events->add_event(std::move(ev));
@ -897,7 +890,7 @@ class Session : ProtoContext,
<< render_options_sanitized(received_options, Option::RENDER_PASS_FMT | Option::RENDER_NUMBER | Option::RENDER_BRACKET)); << render_options_sanitized(received_options, Option::RENDER_PASS_FMT | Option::RENDER_NUMBER | Option::RENDER_BRACKET));
// relay servers are not allowed to establish a tunnel with us // relay servers are not allowed to establish a tunnel with us
if (Base::conf().relay_mode) if (proto_context.conf().relay_mode)
{ {
tun_error(Error::RELAY_ERROR, "tunnel not permitted to relay server"); tun_error(Error::RELAY_ERROR, "tunnel not permitted to relay server");
return; return;
@ -917,28 +910,28 @@ class Session : ProtoContext,
transport_factory->process_push(received_options); transport_factory->process_push(received_options);
// modify proto config (cipher, auth, key-derivation and compression methods) // modify proto config (cipher, auth, key-derivation and compression methods)
Base::process_push(received_options, *proto_context_options); proto_context.process_push(received_options, *proto_context_options);
// initialize tun/routing // initialize tun/routing
tun = tun_factory->new_tun_client_obj(io_context, *this, transport.get()); tun = tun_factory->new_tun_client_obj(io_context, *this, transport.get());
tun->tun_start(received_options, *transport, Base::dc_settings()); tun->tun_start(received_options, *transport, proto_context.dc_settings());
// we should be connected at this point // we should be connected at this point
if (!connected_) if (!connected_)
throw tun_exception("not connected"); throw tun_exception("not connected");
// Propagate tun-mtu back, it might have been overwritten by a pushed tun-mtu option // Propagate tun-mtu back, it might have been overwritten by a pushed tun-mtu option
conf().tun_mtu = tun->vpn_mtu(); proto_context.conf().tun_mtu = tun->vpn_mtu();
// initialize data channel after pushed options have been processed // initialize data channel after pushed options have been processed
Base::init_data_channel(); proto_context.init_data_channel();
// we got pushed options and initializated crypto - now we can push mss to dco // we got pushed options and initializated crypto - now we can push mss to dco
tun->adjust_mss(conf().mss_fix); tun->adjust_mss(proto_context.conf().mss_fix);
// Allow ProtoContext to suggest an alignment adjustment // Allow ProtoContext to suggest an alignment adjustment
// hint for transport layer. // hint for transport layer.
transport->reset_align_adjust(Base::align_adjust_hint()); transport->reset_align_adjust(proto_context.align_adjust_hint());
// process "inactive" directive // process "inactive" directive
process_inactive(received_options); process_inactive(received_options);
@ -954,10 +947,10 @@ class Session : ProtoContext,
cli_events->add_event(connected_); cli_events->add_event(connected_);
// send an event for custom app control if present // send an event for custom app control if present
if (!conf().app_control_config.supported_protocols.empty()) if (!proto_context.conf().app_control_config.supported_protocols.empty())
{ {
// Signal support for supported protocols // Signal support for supported protocols
auto ev = new ClientEvent::AppCustomControlMessage("internal:supported_protocols", string::join(conf().app_control_config.supported_protocols, ":")); auto ev = new ClientEvent::AppCustomControlMessage("internal:supported_protocols", string::join(proto_context.conf().app_control_config.supported_protocols, ":"));
cli_events->add_event(std::move(ev)); cli_events->add_event(std::move(ev));
} }
@ -1050,27 +1043,27 @@ class Session : ProtoContext,
void client_auth(Buffer &buf) override void client_auth(Buffer &buf) override
{ {
// we never send creds to a relay server // we never send creds to a relay server
if (creds && !Base::conf().relay_mode) if (creds && !proto_context.conf().relay_mode)
{ {
OPENVPN_LOG("Creds: " << creds->auth_info()); OPENVPN_LOG("Creds: " << creds->auth_info());
Base::write_auth_string(creds->get_username(), buf); proto_context.write_auth_string(creds->get_username(), buf);
#ifdef OPENVPN_DISABLE_AUTH_TOKEN // debugging only #ifdef OPENVPN_DISABLE_AUTH_TOKEN // debugging only
if (creds->session_id_defined()) if (creds->session_id_defined())
{ {
OPENVPN_LOG("NOTE: not sending auth-token"); OPENVPN_LOG("NOTE: not sending auth-token");
Base::write_empty_string(buf); ProtoContext::write_empty_string(buf);
} }
else else
#endif #endif
{ {
Base::write_auth_string(creds->get_password(), buf); proto_context.write_auth_string(creds->get_password(), buf);
} }
} }
else else
{ {
OPENVPN_LOG("Creds: None"); OPENVPN_LOG("Creds: None");
Base::write_empty_string(buf); // username write_empty_string(buf); // username
Base::write_empty_string(buf); // password write_empty_string(buf); // password
} }
} }
@ -1081,7 +1074,7 @@ class Session : ProtoContext,
{ {
if (!e && !halt && !received_options.partial()) if (!e && !halt && !received_options.partial())
{ {
Base::update_now(); proto_context.update_now();
if (!sent_push_request) if (!sent_push_request)
{ {
ClientEvent::Base::Ptr ev = new ClientEvent::GetConfig(); ClientEvent::Base::Ptr ev = new ClientEvent::GetConfig();
@ -1089,8 +1082,8 @@ class Session : ProtoContext,
sent_push_request = true; sent_push_request = true;
} }
OPENVPN_LOG("Sending PUSH_REQUEST to server..."); OPENVPN_LOG("Sending PUSH_REQUEST to server...");
Base::write_control_string(std::string("PUSH_REQUEST")); proto_context.write_control_string(std::string("PUSH_REQUEST"));
Base::flush(true); proto_context.flush(true);
set_housekeeping_timer(); set_housekeeping_timer();
{ {
@ -1134,7 +1127,7 @@ class Session : ProtoContext,
// react to any tls warning triggered during the tls-handshake // react to any tls warning triggered during the tls-handshake
virtual void check_tls_warnings() virtual void check_tls_warnings()
{ {
uint32_t tls_warnings = get_tls_warnings(); uint32_t tls_warnings = proto_context.get_tls_warnings();
if (tls_warnings & SSLAPI::TLS_WARN_SIG_MD5) if (tls_warnings & SSLAPI::TLS_WARN_SIG_MD5)
{ {
@ -1151,13 +1144,13 @@ class Session : ProtoContext,
void check_proto_warnings() void check_proto_warnings()
{ {
if (uses_bs64_cipher()) if (proto_context.uses_bs64_cipher())
{ {
ClientEvent::Base::Ptr ev = new ClientEvent::Warn("Proto: Using a 64-bit block cipher that is vulnerable to the SWEET32 attack. Please inform your admin to upgrade to a stronger algorithm. Support for 64-bit block cipher will be dropped in the future."); ClientEvent::Base::Ptr ev = new ClientEvent::Warn("Proto: Using a 64-bit block cipher that is vulnerable to the SWEET32 attack. Please inform your admin to upgrade to a stronger algorithm. Support for 64-bit block cipher will be dropped in the future.");
cli_events->add_event(std::move(ev)); cli_events->add_event(std::move(ev));
} }
CompressContext::Type comp_type = Base::conf().comp_ctx.type(); CompressContext::Type comp_type = proto_context.conf().comp_ctx.type();
// abort connection if compression is pushed and its support is unannounced // abort connection if compression is pushed and its support is unannounced
if (comp_type != CompressContext::COMP_STUBv2 if (comp_type != CompressContext::COMP_STUBv2
@ -1203,15 +1196,15 @@ class Session : ProtoContext,
if (!e && !halt) if (!e && !halt)
{ {
// update current time // update current time
Base::update_now(); proto_context.update_now();
housekeeping_schedule.reset(); housekeeping_schedule.reset();
Base::housekeeping(); proto_context.housekeeping();
if (Base::invalidated()) if (proto_context.invalidated())
{ {
if (notify_callback) if (notify_callback)
{ {
OPENVPN_LOG("Session invalidated: " << Error::name(Base::invalidation_reason())); OPENVPN_LOG("Session invalidated: " << Error::name(proto_context.invalidation_reason()));
stop(true); stop(true);
} }
else else
@ -1231,12 +1224,12 @@ class Session : ProtoContext,
if (halt) if (halt)
return; return;
Time next = Base::next_housekeeping(); Time next = proto_context.next_housekeeping();
if (!housekeeping_schedule.similar(next)) if (!housekeeping_schedule.similar(next))
{ {
if (!next.is_infinite()) if (!next.is_infinite())
{ {
next.max(now()); next.max(proto_context.now());
housekeeping_schedule.reset(next); housekeeping_schedule.reset(next);
housekeeping_timer.expires_at(next); housekeeping_timer.expires_at(next);
housekeeping_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error) housekeeping_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error)
@ -1377,7 +1370,7 @@ class Session : ProtoContext,
void schedule_info_hold_callback() void schedule_info_hold_callback()
{ {
Base::update_now(); proto_context.update_now();
info_hold_timer.expires_after(Time::Duration::seconds(1)); info_hold_timer.expires_after(Time::Duration::seconds(1));
info_hold_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error) info_hold_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error)
{ {
@ -1391,7 +1384,7 @@ class Session : ProtoContext,
{ {
if (!e && !halt) if (!e && !halt)
{ {
Base::update_now(); proto_context.update_now();
if (info_hold) if (info_hold)
{ {
for (auto &ev : *info_hold) for (auto &ev : *info_hold)
@ -1420,6 +1413,8 @@ class Session : ProtoContext,
} }
#endif #endif
ProtoContext proto_context;
openvpn_io::io_context &io_context; openvpn_io::io_context &io_context;
TransportClientFactory::Ptr transport_factory; TransportClientFactory::Ptr transport_factory;

View File

@ -109,16 +109,13 @@ class ServerProto
}; };
// This is the main server-side client instance object // This is the main server-side client instance object
class Session : ProtoContext, // OpenVPN protocol implementation class Session : ProtoContextCallbackInterface, // Callback interface from protocol implementation
public TransportLink, // Transport layer public TransportLink, // Transport layer
public TunLink, // Tun/routing layer public TunLink, // Tun/routing layer
public ManLink // Management layer public ManLink // Management layer
{ {
friend class Factory; // calls constructor friend class Factory; // calls constructor
using ProtoContext::now;
using ProtoContext::stat;
public: public:
typedef RCPtr<Session> Ptr; typedef RCPtr<Session> Ptr;
@ -142,11 +139,11 @@ class ServerProto
peer_addr = addr; peer_addr = addr;
// init OpenVPN protocol handshake // init OpenVPN protocol handshake
ProtoContext::update_now(); proto_context.update_now();
ProtoContext::reset(cookie_psid); proto_context.reset(cookie_psid);
ProtoContext::set_local_peer_id(local_peer_id); proto_context.set_local_peer_id(local_peer_id);
ProtoContext::start(cookie_psid); proto_context.start(cookie_psid);
ProtoContext::flush(true); proto_context.flush(true);
// coarse wakeup range // coarse wakeup range
housekeeping_schedule.init(Time::Duration::binary_ms(512), Time::Duration::binary_ms(1024)); housekeeping_schedule.init(Time::Duration::binary_ms(512), Time::Duration::binary_ms(1024));
@ -182,8 +179,8 @@ class ServerProto
ManLink::send->stats_notify(TransportLink::send->stats_poll(), true); ManLink::send->stats_notify(TransportLink::send->stats_poll(), true);
} }
ProtoContext::pre_destroy(); proto_context.pre_destroy();
ProtoContext::reset_dc_factory(); proto_context.reset_dc_factory();
if (TransportLink::send) if (TransportLink::send)
{ {
TransportLink::send->stop(); TransportLink::send->stop();
@ -206,23 +203,23 @@ class ServerProto
virtual bool transport_recv(BufferAllocated &buf) override virtual bool transport_recv(BufferAllocated &buf) override
{ {
bool ret = false; bool ret = false;
if (!ProtoContext::primary_defined()) if (!proto_context.primary_defined())
return false; return false;
try try
{ {
OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport RECV[" << buf.size() << "] " << client_endpoint_render() << ' ' << ProtoContext::dump_packet(buf)); OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport RECV[" << buf.size() << "] " << client_endpoint_render() << ' ' << proto_context.dump_packet(buf));
// update current time // update current time
ProtoContext::update_now(); proto_context.update_now();
// get packet type // get packet type
ProtoContext::PacketType pt = ProtoContext::packet_type(buf); ProtoContext::PacketType pt = proto_context.packet_type(buf);
// process packet // process packet
if (pt.is_data()) if (pt.is_data())
{ {
// data packet // data packet
ret = ProtoContext::data_decrypt(pt, buf); ret = proto_context.data_decrypt(pt, buf);
if (buf.size()) if (buf.size())
{ {
#ifdef OPENVPN_PACKET_LOG #ifdef OPENVPN_PACKET_LOG
@ -237,15 +234,15 @@ class ServerProto
} }
// do a lightweight flush // do a lightweight flush
ProtoContext::flush(false); proto_context.flush(false);
} }
else if (pt.is_control()) else if (pt.is_control())
{ {
// control packet // control packet
ret = ProtoContext::control_net_recv(pt, std::move(buf)); ret = proto_context.control_net_recv(pt, std::move(buf));
// do a full flush // do a full flush
ProtoContext::flush(true); proto_context.flush(true);
} }
// schedule housekeeping wakeup // schedule housekeeping wakeup
@ -269,7 +266,7 @@ class ServerProto
// Return true if keepalive parameter(s) are enabled. // Return true if keepalive parameter(s) are enabled.
virtual bool is_keepalive_enabled() const override virtual bool is_keepalive_enabled() const override
{ {
return ProtoContext::is_keepalive_enabled(); return proto_context.is_keepalive_enabled();
} }
// Disable keepalive for rest of session, but fetch // Disable keepalive for rest of session, but fetch
@ -278,7 +275,7 @@ class ServerProto
virtual void disable_keepalive(unsigned int &keepalive_ping, virtual void disable_keepalive(unsigned int &keepalive_ping,
unsigned int &keepalive_timeout) override unsigned int &keepalive_timeout) override
{ {
ProtoContext::disable_keepalive(keepalive_ping, keepalive_timeout); proto_context.disable_keepalive(keepalive_ping, keepalive_timeout);
if (ManLink::send) if (ManLink::send)
ManLink::send->keepalive_override(keepalive_ping, keepalive_timeout); ManLink::send->keepalive_override(keepalive_ping, keepalive_timeout);
} }
@ -286,7 +283,7 @@ class ServerProto
// override the data channel factory // override the data channel factory
virtual void override_dc_factory(const CryptoDCFactory::Ptr &dc_factory) override virtual void override_dc_factory(const CryptoDCFactory::Ptr &dc_factory) override
{ {
ProtoContext::dc_settings().set_factory(dc_factory); proto_context.dc_settings().set_factory(dc_factory);
} }
virtual ~Session() virtual ~Session()
@ -301,7 +298,7 @@ class ServerProto
const Factory &factory, const Factory &factory,
ManClientInstance::Factory::Ptr man_factory_arg, ManClientInstance::Factory::Ptr man_factory_arg,
TunClientInstance::Factory::Ptr tun_factory_arg) TunClientInstance::Factory::Ptr tun_factory_arg)
: ProtoContext(factory.clone_proto_config(), factory.stats), : proto_context(this, factory.clone_proto_config(), factory.stats),
housekeeping_timer(io_context_arg), housekeeping_timer(io_context_arg),
disconnect_at(Time::infinite()), disconnect_at(Time::infinite()),
stats(factory.stats), stats(factory.stats),
@ -318,11 +315,11 @@ class ServerProto
// proto base class calls here for control channel network sends // proto base class calls here for control channel network sends
virtual void control_net_send(const Buffer &net_buf) override virtual void control_net_send(const Buffer &net_buf) override
{ {
OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport SEND[" << net_buf.size() << "] " << client_endpoint_render() << ' ' << ProtoContext::dump_packet(net_buf)); OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport SEND[" << net_buf.size() << "] " << client_endpoint_render() << ' ' << proto_context.dump_packet(net_buf));
if (TransportLink::send) if (TransportLink::send)
{ {
if (TransportLink::send->transport_send_const(net_buf)) if (TransportLink::send->transport_send_const(net_buf))
ProtoContext::update_last_sent(); proto_context.update_last_sent();
} }
} }
@ -353,7 +350,7 @@ class ServerProto
if (msg == "PUSH_REQUEST") if (msg == "PUSH_REQUEST")
{ {
if (get_management()) if (get_management())
ManLink::send->push_request(ProtoContext::conf_ptr()); ManLink::send->push_request(proto_context.conf_ptr());
else else
auth_failed("no management provider", ""); auth_failed("no management provider", "");
} }
@ -368,6 +365,14 @@ class ServerProto
} }
} }
void active(bool primary) override
{
/* Currently the server does not do anything special when the connection
* is ready (control channel fully established). We probably should trigger
* sending a PUSH_REPLY here, when the client requested it via
* IV_PROTO_REQUEST_PUSH instead waiting for an explicit PUSH_REQUEST */
}
virtual void auth_failed(const std::string &reason, virtual void auth_failed(const std::string &reason,
const std::string &client_reason) override const std::string &client_reason) override
{ {
@ -379,7 +384,7 @@ class ServerProto
if (halt || disconnect_type == DT_HALT_RESTART) if (halt || disconnect_type == DT_HALT_RESTART)
return; return;
ProtoContext::update_now(); proto_context.update_now();
if (TunLink::send && (disconnect_type < DT_RELAY_TRANSITION)) if (TunLink::send && (disconnect_type < DT_RELAY_TRANSITION))
{ {
@ -388,13 +393,13 @@ class ServerProto
disconnect_in(Time::Duration::seconds(10)); // not a real disconnect, just complete transition to relay disconnect_in(Time::Duration::seconds(10)); // not a real disconnect, just complete transition to relay
} }
if (ProtoContext::primary_defined()) if (proto_context.primary_defined())
{ {
BufferPtr buf(new BufferAllocated(64, 0)); BufferPtr buf(new BufferAllocated(64, 0));
buf_append_string(*buf, "RELAY"); buf_append_string(*buf, "RELAY");
buf->null_terminate(); buf->null_terminate();
ProtoContext::control_send(std::move(buf)); proto_context.control_send(std::move(buf));
ProtoContext::flush(true); proto_context.flush(true);
} }
set_housekeeping_timer(); set_housekeeping_timer();
@ -402,7 +407,7 @@ class ServerProto
virtual void push_reply(std::vector<BufferPtr> &&push_msgs) override virtual void push_reply(std::vector<BufferPtr> &&push_msgs) override
{ {
if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !ProtoContext::primary_defined()) if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !proto_context.primary_defined())
return; return;
if (disconnect_type == DT_AUTH_PENDING) if (disconnect_type == DT_AUTH_PENDING)
@ -411,17 +416,17 @@ class ServerProto
cancel_disconnect(); cancel_disconnect();
} }
ProtoContext::update_now(); proto_context.update_now();
if (get_tun()) if (get_tun())
{ {
ProtoContext::init_data_channel(); proto_context.init_data_channel();
for (auto &msg : push_msgs) for (auto &msg : push_msgs)
{ {
msg->null_terminate(); msg->null_terminate();
ProtoContext::control_send(std::move(msg)); proto_context.control_send(std::move(msg));
} }
ProtoContext::flush(true); proto_context.flush(true);
set_housekeeping_timer(); set_housekeeping_timer();
} }
else else
@ -445,7 +450,7 @@ class ServerProto
if (halt || disconnect_type == DT_HALT_RESTART) if (halt || disconnect_type == DT_HALT_RESTART)
return; return;
ProtoContext::update_now(); proto_context.update_now();
BufferPtr buf(new BufferAllocated(128, BufferAllocated::GROW)); BufferPtr buf(new BufferAllocated(128, BufferAllocated::GROW));
BufferStreamOut os(*buf); BufferStreamOut os(*buf);
@ -520,11 +525,11 @@ class ServerProto
OPENVPN_LOG(instance_name() << " : Disconnect: " << ts << ' ' << reason); OPENVPN_LOG(instance_name() << " : Disconnect: " << ts << ' ' << reason);
if (ProtoContext::primary_defined()) if (proto_context.primary_defined())
{ {
buf->null_terminate(); buf->null_terminate();
ProtoContext::control_send(std::move(buf)); proto_context.control_send(std::move(buf));
ProtoContext::flush(true); proto_context.flush(true);
} }
set_housekeeping_timer(); set_housekeeping_timer();
@ -534,7 +539,7 @@ class ServerProto
{ {
if (halt || disconnect_type == DT_HALT_RESTART) if (halt || disconnect_type == DT_HALT_RESTART)
return; return;
ProtoContext::update_now(); proto_context.update_now();
disconnect_in(Time::Duration::seconds(seconds)); disconnect_in(Time::Duration::seconds(seconds));
set_housekeeping_timer(); set_housekeeping_timer();
} }
@ -543,7 +548,7 @@ class ServerProto
{ {
if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !seconds) if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !seconds)
return; return;
ProtoContext::update_now(); proto_context.update_now();
disconnect_type = DT_AUTH_PENDING; disconnect_type = DT_AUTH_PENDING;
disconnect_in(Time::Duration::seconds(seconds)); disconnect_in(Time::Duration::seconds(seconds));
set_housekeeping_timer(); set_housekeeping_timer();
@ -551,41 +556,41 @@ class ServerProto
virtual void post_cc_msg(BufferPtr &&msg) override virtual void post_cc_msg(BufferPtr &&msg) override
{ {
if (halt || !ProtoContext::primary_defined()) if (halt || !proto_context.primary_defined())
return; return;
ProtoContext::update_now(); proto_context.update_now();
msg->null_terminate(); msg->null_terminate();
ProtoContext::control_send(std::move(msg)); proto_context.control_send(std::move(msg));
ProtoContext::flush(true); proto_context.flush(true);
set_housekeeping_timer(); set_housekeeping_timer();
} }
virtual void stats_notify(const PeerStats &ps, const bool final) override void stats_notify(const PeerStats &ps, const bool final) override
{ {
if (ManLink::send) if (ManLink::send)
ManLink::send->stats_notify(ps, final); ManLink::send->stats_notify(ps, final);
} }
virtual void float_notify(const PeerAddr::Ptr &addr) override void float_notify(const PeerAddr::Ptr &addr) override
{ {
if (ManLink::send) if (ManLink::send)
ManLink::send->float_notify(addr); ManLink::send->float_notify(addr);
} }
virtual void ipma_notify(const struct ovpn_tun_head_ipma &ipma) override void ipma_notify(const struct ovpn_tun_head_ipma &ipma) override
{ {
if (ManLink::send) if (ManLink::send)
ManLink::send->ipma_notify(ipma); ManLink::send->ipma_notify(ipma);
} }
virtual void data_limit_notify(const int key_id, void data_limit_notify(const int key_id,
const DataLimit::Mode cdl_mode, const DataLimit::Mode cdl_mode,
const DataLimit::State cdl_status) override const DataLimit::State cdl_status) override
{ {
ProtoContext::update_now(); proto_context.update_now();
ProtoContext::data_limit_notify(key_id, cdl_mode, cdl_status); proto_context.data_limit_notify(key_id, cdl_mode, cdl_status);
ProtoContext::flush(true); proto_context.flush(true);
set_housekeeping_timer(); set_housekeeping_timer();
} }
@ -613,7 +618,7 @@ class ServerProto
// and set_housekeeping_timer() called after this method // and set_housekeeping_timer() called after this method
void disconnect_in(const Time::Duration &dur) void disconnect_in(const Time::Duration &dur)
{ {
disconnect_at = now() + dur; disconnect_at = proto_context.now() + dur;
} }
void cancel_disconnect() void cancel_disconnect()
@ -628,13 +633,13 @@ class ServerProto
if (!e && !halt) if (!e && !halt)
{ {
// update current time // update current time
ProtoContext::update_now(); proto_context.update_now();
housekeeping_schedule.reset(); housekeeping_schedule.reset();
ProtoContext::housekeeping(); proto_context.housekeeping();
if (ProtoContext::invalidated()) if (proto_context.invalidated())
invalidation_error(ProtoContext::invalidation_reason()); invalidation_error(proto_context.invalidation_reason());
else if (now() >= disconnect_at) else if (proto_context.now() >= disconnect_at)
{ {
switch (disconnect_type) switch (disconnect_type)
{ {
@ -642,7 +647,7 @@ class ServerProto
error("disconnect triggered"); error("disconnect triggered");
break; break;
case DT_RELAY_TRANSITION: case DT_RELAY_TRANSITION:
ProtoContext::pre_destroy(); proto_context.pre_destroy();
break; break;
case DT_AUTH_PENDING: case DT_AUTH_PENDING:
auth_failed("Auth Pending Timeout", "Auth Pending Timeout"); auth_failed("Auth Pending Timeout", "Auth Pending Timeout");
@ -664,13 +669,13 @@ class ServerProto
void set_housekeeping_timer() void set_housekeeping_timer()
{ {
Time next = ProtoContext::next_housekeeping(); Time next = proto_context.next_housekeeping();
next.min(disconnect_at); next.min(disconnect_at);
if (!housekeeping_schedule.similar(next)) if (!housekeeping_schedule.similar(next))
{ {
if (!next.is_infinite()) if (!next.is_infinite())
{ {
next.max(now()); next.max(proto_context.now());
housekeeping_schedule.reset(next); housekeeping_schedule.reset(next);
housekeeping_timer.expires_at(next); housekeeping_timer.expires_at(next);
housekeeping_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error) housekeeping_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error)
@ -738,6 +743,8 @@ class ServerProto
DT_RELAY_TRANSITION, DT_RELAY_TRANSITION,
DT_HALT_RESTART, DT_HALT_RESTART,
}; };
ProtoContext proto_context;
int disconnect_type = DT_NONE; int disconnect_type = DT_NONE;
bool preserve_session_id = true; bool preserve_session_id = true;

View File

@ -168,6 +168,53 @@ enum
} // namespace } // namespace
} // namespace proto_context_private } // namespace proto_context_private
class ProtoContextCallbackInterface
{
public:
/**
* Sends out bytes to the network.
*/
virtual void control_net_send(const Buffer &net_buf) = 0;
/*
* Receive as packet from the network
* \note app may take ownership of app_bp via std::move
*/
virtual void control_recv(BufferPtr &&app_bp) = 0;
/** Called on client to request username/password credentials.
* Should be overridden by derived class if credentials are required.
* username and password should be written into buf with write_auth_string().
*/
virtual void client_auth(Buffer &buf)
{
write_empty_string(buf); // username
write_empty_string(buf); // password
}
/** Called on server with credentials and peer info provided by client.
*Should be overriden by derived class if credentials are required. */
virtual void server_auth(const std::string &username,
const SafeString &password,
const std::string &peer_info,
const AuthCert::Ptr &auth_cert)
{
}
/**
* Writes an empty user or password string for the key-method 2 packet in the OpenVPN protocol
* @param buf buffer to write to
*/
static void write_empty_string(Buffer &buf)
{
uint8_t empty[]{0x00, 0x00}; // empty length field without content
buf.write(&empty, 2);
}
//! Called when KeyContext transitions to ACTIVE state
virtual void active(bool primary) = 0;
};
class ProtoContext class ProtoContext
{ {
protected: protected:
@ -1388,9 +1435,8 @@ class ProtoContext
return out.str(); return out.str();
} }
protected: // used for reading/writing authentication strings (username, password, etc.) from buffer using the
// used for reading/writing authentication strings (username, password, etc.) // 2 byte prefix for length
static void write_uint16_length(const size_t size, Buffer &buf) static void write_uint16_length(const size_t size, Buffer &buf)
{ {
if (size > 0xFFFF) if (size > 0xFFFF)
@ -1446,6 +1492,11 @@ class ProtoContext
buf.null_terminate(); buf.null_terminate();
} }
static void write_empty_string(Buffer &buf)
{
write_uint16_length(0, buf);
}
template <typename S> template <typename S>
static S read_control_string(const Buffer &buf) static S read_control_string(const Buffer &buf)
{ {
@ -1469,17 +1520,7 @@ class ProtoContext
control_send(std::move(bp)); control_send(std::move(bp));
} }
static unsigned char *skip_string(Buffer &buf) protected:
{
const size_t len = read_uint16_length(buf);
return buf.read_alloc(len);
}
static void write_empty_string(Buffer &buf)
{
write_uint16_length(0, buf);
}
// Packet structure for managing network packets, passed as a template // Packet structure for managing network packets, passed as a template
// parameter to ProtoStackBase // parameter to ProtoStackBase
class Packet class Packet
@ -2785,7 +2826,7 @@ class ProtoContext
const std::string username = read_auth_string<std::string>(*buf); const std::string username = read_auth_string<std::string>(*buf);
const SafeString password = read_auth_string<SafeString>(*buf); const SafeString password = read_auth_string<SafeString>(*buf);
const std::string peer_info = read_auth_string<std::string>(*buf); const std::string peer_info = read_auth_string<std::string>(*buf);
proto.server_auth(username, password, peer_info, Base::auth_cert()); proto.proto_callback->server_auth(username, password, peer_info, Base::auth_cert());
} }
} }
@ -3641,9 +3682,11 @@ class ProtoContext
OPENVPN_SIMPLE_EXCEPTION(select_key_context_error); OPENVPN_SIMPLE_EXCEPTION(select_key_context_error);
ProtoContext(const ProtoConfig::Ptr &config_arg, // configuration ProtoContext(ProtoContextCallbackInterface *cb_arg,
const ProtoConfig::Ptr &config_arg, // configuration
const SessionStats::Ptr &stats_arg) // error stats const SessionStats::Ptr &stats_arg) // error stats
: config(config_arg), : proto_callback(cb_arg),
config(config_arg),
stats(stats_arg), stats(stats_arg),
mode_(config_arg->ssl_factory->mode()), mode_(config_arg->ssl_factory->mode()),
n_key_ids(0), n_key_ids(0),
@ -4263,8 +4306,13 @@ class ProtoContext
return *stats; return *stats;
} }
protected:
// debugging // debugging
bool is_state_client_wait_reset_ack() const
{
return primary_state() == C_WAIT_RESET_ACK;
}
protected:
int primary_state() const int primary_state() const
{ {
if (primary) if (primary)
@ -4291,32 +4339,11 @@ class ProtoContext
secondary.reset(); secondary.reset();
} }
virtual void control_net_send(const Buffer &net_buf) = 0;
// app may take ownership of app_bp via std::move
virtual void control_recv(BufferPtr &&app_bp) = 0;
// Called on client to request username/password credentials. // Called on client to request username/password credentials.
// Should be overriden by derived class if credentials are required. // delegated to the callback/parent
// username and password should be written into buf with write_auth_string(). void client_auth(Buffer &buf)
virtual void client_auth(Buffer &buf)
{
write_empty_string(buf); // username
write_empty_string(buf); // password
}
// Called on server with credentials and peer info provided by client.
// Should be overriden by derived class if credentials are required.
virtual void server_auth(const std::string &username,
const SafeString &password,
const std::string &peer_info,
const AuthCert::Ptr &auth_cert)
{
}
// Called when KeyContext transitions to ACTIVE state
virtual void active(bool primary)
{ {
proto_callback->client_auth(buf);
} }
void update_last_received() void update_last_received()
@ -4326,12 +4353,12 @@ class ProtoContext
void net_send(const unsigned int key_id, const Packet &net_pkt) void net_send(const unsigned int key_id, const Packet &net_pkt)
{ {
control_net_send(net_pkt.buffer()); proto_callback->control_net_send(net_pkt.buffer());
} }
void app_recv(const unsigned int key_id, BufferPtr &&to_app_buf) void app_recv(const unsigned int key_id, BufferPtr &&to_app_buf)
{ {
control_recv(std::move(to_app_buf)); proto_callback->control_recv(std::move(to_app_buf));
} }
// we're getting a request from peer to renegotiate. // we're getting a request from peer to renegotiate.
@ -4471,7 +4498,7 @@ class ProtoContext
case KeyContext::KEV_ACTIVE: case KeyContext::KEV_ACTIVE:
OPENVPN_LOG_PROTO_VERBOSE(debug_prefix() << " SESSION_ACTIVE"); OPENVPN_LOG_PROTO_VERBOSE(debug_prefix() << " SESSION_ACTIVE");
primary->rekey(CryptoDCInstance::ACTIVATE_PRIMARY); primary->rekey(CryptoDCInstance::ACTIVATE_PRIMARY);
active(true); proto_callback->active(true);
break; break;
case KeyContext::KEV_RENEGOTIATE: case KeyContext::KEV_RENEGOTIATE:
case KeyContext::KEV_RENEGOTIATE_FORCE: case KeyContext::KEV_RENEGOTIATE_FORCE:
@ -4511,7 +4538,7 @@ class ProtoContext
secondary->rekey(CryptoDCInstance::NEW_SECONDARY); secondary->rekey(CryptoDCInstance::NEW_SECONDARY);
if (primary) if (primary)
primary->prepare_expire(); primary->prepare_expire();
active(false); proto_callback->active(false);
break; break;
case KeyContext::KEV_BECOME_PRIMARY: case KeyContext::KEV_BECOME_PRIMARY:
if (!secondary->invalidated()) if (!secondary->invalidated())
@ -4590,6 +4617,13 @@ class ProtoContext
// BEGIN ProtoContext data members // BEGIN ProtoContext data members
/** the class that uses this class needs to be called back on a few things. Typically a class
* that uses this class as field for composition. This parent/callback class needs to ensure that
* it lives longer than this class, e.g. by having this class as field as this class blindly
* assumes that this pointer is always valid for its lifetime
*/
ProtoContextCallbackInterface *proto_callback;
ProtoConfig::Ptr config; ProtoConfig::Ptr config;
SessionStats::Ptr stats; SessionStats::Ptr stats;

View File

@ -353,24 +353,20 @@ class DroughtMeasure
}; };
// test the OpenVPN protocol implementation in ProtoContext // test the OpenVPN protocol implementation in ProtoContext
class TestProto : public ProtoContext class TestProto : public ProtoContextCallbackInterface
{ {
typedef ProtoContext Base; /* Callback methods that are not used */
void active(bool primary) override
{
}
using Base::is_server;
using Base::mode;
using Base::now;
public: public:
using Base::flush;
typedef Base::PacketType PacketType;
OPENVPN_EXCEPTION(session_invalidated); OPENVPN_EXCEPTION(session_invalidated);
TestProto(const Base::ProtoConfig::Ptr &config, TestProto(const ProtoContext::ProtoConfig::Ptr &config,
const SessionStats::Ptr &stats) const SessionStats::Ptr &stats)
: Base(config, stats), : proto_context(this, config, stats),
control_drought("control", config->now), control_drought("control", config->now),
data_drought("data", config->now), data_drought("data", config->now),
frame(config->frame) frame(config->frame)
@ -382,27 +378,26 @@ class TestProto : public ProtoContext
void reset() void reset()
{ {
net_out.clear(); net_out.clear();
Base::reset(); proto_context.reset();
proto_context.conf().mss_parms.mssfix = MSSParms::MSSFIX_DEFAULT;
Base::conf().mss_parms.mssfix = MSSParms::MSSFIX_DEFAULT;
} }
void initial_app_send(const char *msg) void initial_app_send(const char *msg)
{ {
Base::start(); proto_context.start();
const size_t msglen = std::strlen(msg) + 1; const size_t msglen = std::strlen(msg) + 1;
BufferAllocated app_buf((unsigned char *)msg, msglen, 0); BufferAllocated app_buf((unsigned char *)msg, msglen, 0);
copy_progress(app_buf); copy_progress(app_buf);
control_send(std::move(app_buf)); control_send(std::move(app_buf));
flush(true); proto_context.flush(true);
} }
void app_send_templ_init(const char *msg) void app_send_templ_init(const char *msg)
{ {
Base::start(); proto_context.start();
const size_t msglen = std::strlen(msg) + 1; const size_t msglen = std::strlen(msg) + 1;
templ.reset(new BufferAllocated((unsigned char *)msg, msglen, 0)); templ.reset(new BufferAllocated((unsigned char *)msg, msglen, 0));
flush(true); proto_context.flush(true);
} }
void app_send_templ() void app_send_templ()
@ -421,9 +416,9 @@ class TestProto : public ProtoContext
bool do_housekeeping() bool do_housekeeping()
{ {
if (now() >= Base::next_housekeeping()) if (proto_context.now() >= proto_context.next_housekeeping())
{ {
Base::housekeeping(); proto_context.housekeeping();
return true; return true;
} }
else else
@ -433,13 +428,13 @@ class TestProto : public ProtoContext
void control_send(BufferPtr &&app_bp) void control_send(BufferPtr &&app_bp)
{ {
app_bytes_ += app_bp->size(); app_bytes_ += app_bp->size();
Base::control_send(std::move(app_bp)); proto_context.control_send(std::move(app_bp));
} }
void control_send(BufferAllocated &&app_buf) void control_send(BufferAllocated &&app_buf)
{ {
app_bytes_ += app_buf.size(); app_bytes_ += app_buf.size();
Base::control_send(std::move(app_buf)); proto_context.control_send(std::move(app_buf));
} }
BufferPtr data_encrypt_string(const char *str) BufferPtr data_encrypt_string(const char *str)
@ -453,12 +448,12 @@ class TestProto : public ProtoContext
void data_encrypt(BufferAllocated &in_out) void data_encrypt(BufferAllocated &in_out)
{ {
Base::data_encrypt(in_out); proto_context.data_encrypt(in_out);
} }
void data_decrypt(const PacketType &type, BufferAllocated &in_out) void data_decrypt(const ProtoContext::PacketType &type, BufferAllocated &in_out)
{ {
Base::data_decrypt(type, in_out); proto_context.data_decrypt(type, in_out);
if (in_out.size()) if (in_out.size())
{ {
data_bytes_ += in_out.size(); data_bytes_ += in_out.size();
@ -500,13 +495,8 @@ class TestProto : public ProtoContext
void check_invalidated() void check_invalidated()
{ {
if (Base::invalidated()) if (proto_context.invalidated())
throw session_invalidated(Error::name(Base::invalidation_reason())); throw session_invalidated(Error::name(proto_context.invalidation_reason()));
}
bool is_state_client_wait_reset_ack() const
{
return primary_state() == C_WAIT_RESET_ACK;
} }
void disable_xmit() void disable_xmit()
@ -514,13 +504,15 @@ class TestProto : public ProtoContext
disable_xmit_ = true; disable_xmit_ = true;
} }
ProtoContext proto_context;
std::deque<BufferPtr> net_out; std::deque<BufferPtr> net_out;
DroughtMeasure control_drought; DroughtMeasure control_drought;
DroughtMeasure data_drought; DroughtMeasure data_drought;
private: private:
virtual void control_net_send(const Buffer &net_buf) void control_net_send(const Buffer &net_buf) override
{ {
if (disable_xmit_) if (disable_xmit_)
return; return;
@ -528,7 +520,7 @@ class TestProto : public ProtoContext
net_out.push_back(BufferPtr(new BufferAllocated(net_buf, 0))); net_out.push_back(BufferPtr(new BufferAllocated(net_buf, 0)));
} }
virtual void control_recv(BufferPtr &&app_bp) void control_recv(BufferPtr &&app_bp) override
{ {
BufferPtr work; BufferPtr work;
work.swap(app_bp); work.swap(app_bp);
@ -559,7 +551,7 @@ class TestProto : public ProtoContext
void modmsg(BufferPtr &buf) void modmsg(BufferPtr &buf)
{ {
char *msg = (char *)buf->data(); char *msg = (char *)buf->data();
if (is_server()) if (proto_context.is_server())
{ {
msg[8] = 'S'; msg[8] = 'S';
msg[11] = 'C'; msg[11] = 'C';
@ -602,9 +594,9 @@ class TestProtoClient : public TestProto
typedef TestProto Base; typedef TestProto Base;
public: public:
TestProtoClient(const Base::ProtoConfig::Ptr &config, TestProtoClient(const ProtoContext::ProtoConfig::Ptr &config,
const SessionStats::Ptr &stats) const SessionStats::Ptr &stats)
: Base(config, stats) : TestProto(config, stats)
{ {
} }
@ -613,21 +605,26 @@ class TestProtoClient : public TestProto
{ {
const std::string username("foo"); const std::string username("foo");
const std::string password("bar"); const std::string password("bar");
Base::write_auth_string(username, buf); ProtoContext::write_auth_string(username, buf);
Base::write_auth_string(password, buf); ProtoContext::write_auth_string(password, buf);
} }
}; };
class TestProtoServer : public TestProto class TestProtoServer : public TestProto
{ {
typedef TestProto Base;
public: public:
void start()
{
proto_context.start();
}
OPENVPN_SIMPLE_EXCEPTION(auth_failed); OPENVPN_SIMPLE_EXCEPTION(auth_failed);
TestProtoServer(const Base::ProtoConfig::Ptr &config,
TestProtoServer(const ProtoContext::ProtoConfig::Ptr &config,
const SessionStats::Ptr &stats) const SessionStats::Ptr &stats)
: Base(config, stats) : TestProto(config, stats)
{ {
} }
@ -687,7 +684,7 @@ class NoisyWire
a.app_send_templ(); a.app_send_templ();
// queue a data channel packet // queue a data channel packet
if (a.data_channel_ready()) if (a.proto_context.data_channel_ready())
{ {
BufferPtr bp = a.data_encrypt_string("Waiting for godot A... Waiting for godot B... Waiting for godot C... Waiting for godot D... Waiting for godot E... Waiting for godot F... Waiting for godot G... Waiting for godot H... Waiting for godot I... Waiting for godot J..."); BufferPtr bp = a.data_encrypt_string("Waiting for godot A... Waiting for godot B... Waiting for godot C... Waiting for godot D... Waiting for godot E... Waiting for godot F... Waiting for godot G... Waiting for godot H... Waiting for godot I... Waiting for godot J...");
wire.push_back(bp); wire.push_back(bp);
@ -710,14 +707,14 @@ class NoisyWire
BufferPtr bp = recv(); BufferPtr bp = recv();
if (!bp) if (!bp)
break; break;
typename T2::PacketType pt = b.packet_type(*bp); typename ProtoContext::PacketType pt = b.proto_context.packet_type(*bp);
if (pt.is_control()) if (pt.is_control())
{ {
#ifdef VERBOSE #ifdef VERBOSE
if (!b.control_net_validate(pt, *bp)) // not strictly necessary since control_net_recv will also validate if (!b.control_net_validate(pt, *bp)) // not strictly necessary since control_net_recv will also validate
std::cout << now->raw() << " " << title << " CONTROL PACKET VALIDATION FAILED" << std::endl; std::cout << now->raw() << " " << title << " CONTROL PACKET VALIDATION FAILED" << std::endl;
#endif #endif
b.control_net_recv(pt, std::move(bp)); b.proto_context.control_net_recv(pt, std::move(bp));
} }
else if (pt.is_data()) else if (pt.is_data())
{ {
@ -744,11 +741,11 @@ class NoisyWire
#ifdef VERBOSE #ifdef VERBOSE
std::cout << now->raw() << " " << title << " KEY_STATE_ERROR" << std::endl; std::cout << now->raw() << " " << title << " KEY_STATE_ERROR" << std::endl;
#endif #endif
b.stat().error(Error::KEY_STATE_ERROR); b.proto_context.stat().error(Error::KEY_STATE_ERROR);
} }
#ifdef SIMULATE_UDP_AMPLIFY_ATTACK #ifdef SIMULATE_UDP_AMPLIFY_ATTACK
if (b.is_state_client_wait_reset_ack()) if (b.proto_context.is_state_client_wait_reset_ack())
{ {
b.disable_xmit(); b.disable_xmit();
#ifdef VERBOSE #ifdef VERBOSE
@ -757,7 +754,7 @@ class NoisyWire
} }
#endif #endif
} }
b.flush(true); b.proto_context.flush(true);
} }
private: private:
@ -1148,8 +1145,8 @@ int test(const int thread_num)
<< " CTRL=" << cli_proto.n_control_recv() << '/' << cli_proto.n_control_send() << '/' << serv_proto.n_control_recv() << '/' << serv_proto.n_control_send() << " CTRL=" << cli_proto.n_control_recv() << '/' << cli_proto.n_control_send() << '/' << serv_proto.n_control_recv() << '/' << serv_proto.n_control_send()
#endif #endif
<< " D=" << cli_proto.control_drought().raw() << '/' << cli_proto.data_drought().raw() << '/' << serv_proto.control_drought().raw() << '/' << serv_proto.data_drought().raw() << " D=" << cli_proto.control_drought().raw() << '/' << cli_proto.data_drought().raw() << '/' << serv_proto.control_drought().raw() << '/' << serv_proto.data_drought().raw()
<< " N=" << cli_proto.negotiations() << '/' << serv_proto.negotiations() << " N=" << cli_proto.proto_context.negotiations() << '/' << serv_proto.proto_context.negotiations()
<< " SH=" << cli_proto.slowest_handshake().raw() << '/' << serv_proto.slowest_handshake().raw() << " SH=" << cli_proto.proto_context.slowest_handshake().raw() << '/' << serv_proto.proto_context.slowest_handshake().raw()
<< " HE=" << cli_stats->get_error_count(Error::HANDSHAKE_TIMEOUT) << '/' << serv_stats->get_error_count(Error::HANDSHAKE_TIMEOUT) << " HE=" << cli_stats->get_error_count(Error::HANDSHAKE_TIMEOUT) << '/' << serv_stats->get_error_count(Error::HANDSHAKE_TIMEOUT)
<< std::endl; << std::endl;