0
0
mirror of https://github.com/OpenVPN/openvpn3.git synced 2024-09-19 19:52:15 +02:00

Refactor ClientProto::Session to use ProtoContext as field insatead of Base

Currently the protocontext is used as kind of composition but not really
and makes following the code harder, since this inheritance not only serves
for composition but also as callbacks through virtual method inheritance.

Making ProtoContext a normal field and definining a callback interface makes
the class relationship easier to understand.

Signed-off-by: Arne Schwabe <arne@openvpn.net>
This commit is contained in:
Arne Schwabe 2024-02-22 07:26:46 +01:00 committed by Jenkins-dev
parent 2ff8029eba
commit e14c3f0441
5 changed files with 282 additions and 249 deletions

View File

@ -1164,7 +1164,7 @@ class ClientOptions : public RC<thread_unsafe_refcount>
// Copy ProtoConfig so that modifications due to server push will
// not persist across client instantiations.
cli_config->proto_context_config.reset(new Client::ProtoConfig(proto_config_cached(relay_mode)));
cli_config->proto_context_config.reset(new ProtoContext::ProtoConfig(proto_config_cached(relay_mode)));
cli_config->proto_context_options = proto_context_options;
cli_config->push_base = push_base;
@ -1274,7 +1274,7 @@ class ClientOptions : public RC<thread_unsafe_refcount>
}
private:
Client::ProtoConfig &proto_config_cached(const bool relay_mode)
ProtoContext::ProtoConfig &proto_config_cached(const bool relay_mode)
{
if (relay_mode && cp_relay)
return *cp_relay;
@ -1282,14 +1282,14 @@ class ClientOptions : public RC<thread_unsafe_refcount>
return *cp_main;
}
Client::ProtoConfig::Ptr proto_config(const OptionList &opt,
const Config &config,
const ParseClientConfig &pcc,
const bool relay_mode)
ProtoContext::ProtoConfig::Ptr proto_config(const OptionList &opt,
const Config &config,
const ParseClientConfig &pcc,
const bool relay_mode)
{
// relay mode is null unless one of the below directives is defined
if (relay_mode && !opt.exists("relay-mode"))
return Client::ProtoConfig::Ptr();
return ProtoContext::ProtoConfig::Ptr();
// load flags
unsigned int lflags = SSLConfigAPI::LF_PARSE_MODE;
@ -1314,7 +1314,7 @@ class ClientOptions : public RC<thread_unsafe_refcount>
cc->set_tls_ciphersuite_list(config.clientconf.tlsCiphersuitesList);
// client ProtoContext config
Client::ProtoConfig::Ptr cp(new Client::ProtoConfig());
ProtoContext::ProtoConfig::Ptr cp(new ProtoContext::ProtoConfig());
cp->ssl_factory = cc->new_factory();
cp->relay_mode = relay_mode;
cp->dc.set_factory(new CryptoDCSelect<SSLLib::CryptoAPI>(cp->ssl_factory->libctx(), frame, cli_stats, rng));
@ -1454,8 +1454,8 @@ class ClientOptions : public RC<thread_unsafe_refcount>
RandomAPI::Ptr prng;
Frame::Ptr frame;
Layer layer;
Client::ProtoConfig::Ptr cp_main;
Client::ProtoConfig::Ptr cp_relay;
ProtoContext::ProtoConfig::Ptr cp_main;
ProtoContext::ProtoConfig::Ptr cp_relay;
RemoteList::Ptr remote_list;
bool server_addr_float;
TransportClientFactory::Ptr transport_factory;

View File

@ -92,20 +92,13 @@ struct NotifyCallback
}
};
class Session : ProtoContext,
class Session : ProtoContextCallbackInterface,
TransportClientParent,
TunClientParent,
public RC<thread_unsafe_refcount>
{
typedef ProtoContext Base;
typedef Base::PacketType PacketType;
using Base::now;
using Base::stat;
public:
typedef RCPtr<Session> Ptr;
typedef Base::ProtoConfig ProtoConfig;
OPENVPN_EXCEPTION(client_exception);
OPENVPN_EXCEPTION(client_halt_restart);
@ -133,7 +126,7 @@ class Session : ProtoContext,
{
}
ProtoConfig::Ptr proto_context_config;
ProtoContext::ProtoConfig::Ptr proto_context_config;
ProtoContextCompressionOptions::Ptr proto_context_options;
PushOptionsBase::Ptr push_base;
TransportClientFactory::Ptr transport_factory;
@ -152,7 +145,7 @@ class Session : ProtoContext,
Session(openvpn_io::io_context &io_context_arg,
const Config &config,
NotifyCallback *notify_callback_arg)
: Base(config.proto_context_config, config.cli_stats),
: proto_context(this, config.proto_context_config, config.cli_stats),
io_context(io_context_arg),
transport_factory(config.transport_factory),
tun_factory(config.tun_factory),
@ -178,9 +171,9 @@ class Session : ProtoContext,
if (!packet_log)
OPENVPN_THROW(open_file_error, "cannot open packet log for output: " << OPENVPN_PACKET_LOG);
#endif
Base::update_now();
Base::reset();
// Base::enable_strict_openvpn_2x();
proto_context.update_now();
proto_context.reset();
// proto_context.enable_strict_openvpn_2x();
info_hold.reset(new std::vector<ClientEvent::Base::Ptr>());
}
@ -194,7 +187,7 @@ class Session : ProtoContext,
{
if (!halt)
{
Base::update_now();
proto_context.update_now();
// coarse wakeup range
housekeeping_schedule.init(Time::Duration::binary_ms(512), Time::Duration::binary_ms(1024));
@ -224,7 +217,7 @@ class Session : ProtoContext,
void send_explicit_exit_notify()
{
if (!halt)
Base::send_explicit_exit_notify();
proto_context.send_explicit_exit_notify();
}
void tun_set_disconnect()
@ -235,22 +228,22 @@ class Session : ProtoContext,
void post_cc_msg(const std::string &msg)
{
Base::update_now();
Base::write_control_string(msg);
Base::flush(true);
proto_context.update_now();
proto_context.write_control_string(msg);
proto_context.flush(true);
set_housekeeping_timer();
}
void post_app_control_message(const std::string proto, const std::string message)
{
if (!conf().app_control_config.supports_protocol(proto))
if (!proto_context.conf().app_control_config.supports_protocol(proto))
{
ClientEvent::Base::Ptr ev = new ClientEvent::UnsupportedFeature{"missing acc protocol support", "server has not announced support of this custom app control protocol", false};
cli_events->add_event(std::move(ev));
return;
}
for (auto fragment : conf().app_control_config.format_message(proto, message))
for (auto fragment : proto_context.conf().app_control_config.format_message(proto, message))
post_cc_msg(std::move(fragment));
}
@ -328,13 +321,13 @@ class Session : ProtoContext,
{
try
{
OPENVPN_LOG_CLIPROTO("Transport RECV " << server_endpoint_render() << ' ' << Base::dump_packet(buf));
OPENVPN_LOG_CLIPROTO("Transport RECV " << server_endpoint_render() << ' ' << proto_context.dump_packet(buf));
// update current time
Base::update_now();
proto_context.update_now();
// update last packet received
stat().update_last_packet_received(now());
proto_context.stat().update_last_packet_received(proto_context.now());
// log connecting event (only on first packet received)
if (!first_packet_received_)
@ -345,13 +338,13 @@ class Session : ProtoContext,
}
// get packet type
Base::PacketType pt = Base::packet_type(buf);
ProtoContext::PacketType pt = proto_context.packet_type(buf);
// process packet
if (pt.is_data())
{
// data packet
Base::data_decrypt(pt, buf);
proto_context.data_decrypt(pt, buf);
if (buf.size())
{
#ifdef OPENVPN_PACKET_LOG
@ -366,15 +359,15 @@ class Session : ProtoContext,
}
// do a lightweight flush
Base::flush(false);
proto_context.flush(false);
}
else if (pt.is_control())
{
// control packet
Base::control_net_recv(pt, std::move(buf));
proto_context.control_net_recv(pt, std::move(buf));
// do a full flush
Base::flush(true);
proto_context.flush(true);
}
else
cli_stats->error(Error::KEY_STATE_ERROR);
@ -412,7 +405,7 @@ class Session : ProtoContext,
OPENVPN_LOG_CLIPROTO("TUN recv, size=" << buf.size());
// update current time
Base::update_now();
proto_context.update_now();
// log packet
#ifdef OPENVPN_PACKET_LOG
@ -432,7 +425,7 @@ class Session : ProtoContext,
// encrypt packet
if (buf.size())
{
const ProtoContext::ProtoConfig &c = Base::conf();
const ProtoContext::ProtoConfig &c = proto_context.conf();
// when calculating mss, we take IPv4 and TCP headers into account
// here we need to add it back since we check the whole IP packet size, not just TCP payload
constexpr size_t MinTcpHeader = 20;
@ -445,13 +438,13 @@ class Session : ProtoContext,
}
else
{
Base::data_encrypt(buf);
proto_context.data_encrypt(buf);
if (buf.size())
{
// send packet via transport to destination
OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << Base::dump_packet(buf));
OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << proto_context.dump_packet(buf));
if (transport->transport_send(buf))
Base::update_last_sent();
proto_context.update_last_sent();
else if (halt)
return;
}
@ -459,7 +452,7 @@ class Session : ProtoContext,
}
// do a lightweight flush
Base::flush(false);
proto_context.flush(false);
// schedule housekeeping wakeup
set_housekeeping_timer();
@ -473,7 +466,7 @@ class Session : ProtoContext,
// Return true if keepalive parameter(s) are enabled.
bool is_keepalive_enabled() const override
{
return Base::is_keepalive_enabled();
return proto_context.is_keepalive_enabled();
}
// Disable keepalive for rest of session, but fetch
@ -481,7 +474,7 @@ class Session : ProtoContext,
void disable_keepalive(unsigned int &keepalive_ping,
unsigned int &keepalive_timeout) override
{
Base::disable_keepalive(keepalive_ping, keepalive_timeout);
proto_context.disable_keepalive(keepalive_ping, keepalive_timeout);
}
void transport_pre_resolve() override
@ -516,9 +509,9 @@ class Session : ProtoContext,
try
{
OPENVPN_LOG("Connecting to " << server_endpoint_render());
Base::set_protocol(transport->transport_protocol());
Base::start();
Base::flush(true);
proto_context.set_protocol(transport->transport_protocol());
proto_context.start();
proto_context.flush(true);
set_housekeeping_timer();
}
catch (const std::exception &e)
@ -587,7 +580,7 @@ class Session : ProtoContext,
OPENVPN_LOG("Session token: [redacted]");
#endif
autologin_sessions = true;
conf().set_xmit_creds(true);
proto_context.conf().set_xmit_creds(true);
creds->set_replace_password_with_session_id(true);
creds->set_session_id(username, sess_id);
}
@ -688,9 +681,9 @@ class Session : ProtoContext,
// proto base class calls here for control channel network sends
void control_net_send(const Buffer &net_buf) override
{
OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << Base::dump_packet(net_buf));
OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << proto_context.dump_packet(net_buf));
if (transport->transport_send_const(net_buf))
Base::update_last_sent();
proto_context.update_last_sent();
}
void recv_auth_failed(const std::string &msg)
@ -753,7 +746,7 @@ class Session : ProtoContext,
{
timeout = clamp_to_typerange<unsigned int>(std::stoul(timeout_str));
// Cap the timeout to end well before renegotiation starts
timeout = std::min(timeout, static_cast<decltype(timeout)>(conf().renegotiate.to_seconds() / 2));
timeout = std::min(timeout, static_cast<decltype(timeout)>(proto_context.conf().renegotiate.to_seconds() / 2));
}
catch (const std::logic_error &)
{
@ -776,7 +769,7 @@ class Session : ProtoContext,
void recv_relay()
{
if (conf().relay_mode)
if (proto_context.conf().relay_mode)
{
fatal_ = Error::RELAY;
fatal_reason_ = "";
@ -818,7 +811,7 @@ class Session : ProtoContext,
// proto base class calls here for app-level control-channel messages received
void control_recv(BufferPtr &&app_bp) override
{
const std::string msg = Unicode::utf8_printable(Base::template read_control_string<std::string>(*app_bp),
const std::string msg = Unicode::utf8_printable(ProtoContext::template read_control_string<std::string>(*app_bp),
Unicode::UTF8_FILTER | Unicode::UTF8_PASS_FMT);
// OPENVPN_LOG("SERVER: " << sanitize_control_message(msg));
@ -859,13 +852,13 @@ class Session : ProtoContext,
void recv_custom_control_message(const std::string msg)
{
bool fullmessage = conf().app_control_recv.receive_message(msg);
bool fullmessage = proto_context.conf().app_control_recv.receive_message(msg);
if (!fullmessage)
return;
auto [proto, app_proto_msg] = conf().app_control_recv.get_message();
auto [proto, app_proto_msg] = proto_context.conf().app_control_recv.get_message();
if (conf().app_control_config.supports_protocol(proto))
if (proto_context.conf().app_control_config.supports_protocol(proto))
{
auto ev = new ClientEvent::AppCustomControlMessage(std::move(proto), std::move(app_proto_msg));
cli_events->add_event(std::move(ev));
@ -897,7 +890,7 @@ class Session : ProtoContext,
<< render_options_sanitized(received_options, Option::RENDER_PASS_FMT | Option::RENDER_NUMBER | Option::RENDER_BRACKET));
// relay servers are not allowed to establish a tunnel with us
if (Base::conf().relay_mode)
if (proto_context.conf().relay_mode)
{
tun_error(Error::RELAY_ERROR, "tunnel not permitted to relay server");
return;
@ -917,28 +910,28 @@ class Session : ProtoContext,
transport_factory->process_push(received_options);
// modify proto config (cipher, auth, key-derivation and compression methods)
Base::process_push(received_options, *proto_context_options);
proto_context.process_push(received_options, *proto_context_options);
// initialize tun/routing
tun = tun_factory->new_tun_client_obj(io_context, *this, transport.get());
tun->tun_start(received_options, *transport, Base::dc_settings());
tun->tun_start(received_options, *transport, proto_context.dc_settings());
// we should be connected at this point
if (!connected_)
throw tun_exception("not connected");
// Propagate tun-mtu back, it might have been overwritten by a pushed tun-mtu option
conf().tun_mtu = tun->vpn_mtu();
proto_context.conf().tun_mtu = tun->vpn_mtu();
// initialize data channel after pushed options have been processed
Base::init_data_channel();
proto_context.init_data_channel();
// we got pushed options and initializated crypto - now we can push mss to dco
tun->adjust_mss(conf().mss_fix);
tun->adjust_mss(proto_context.conf().mss_fix);
// Allow ProtoContext to suggest an alignment adjustment
// hint for transport layer.
transport->reset_align_adjust(Base::align_adjust_hint());
transport->reset_align_adjust(proto_context.align_adjust_hint());
// process "inactive" directive
process_inactive(received_options);
@ -954,10 +947,10 @@ class Session : ProtoContext,
cli_events->add_event(connected_);
// send an event for custom app control if present
if (!conf().app_control_config.supported_protocols.empty())
if (!proto_context.conf().app_control_config.supported_protocols.empty())
{
// Signal support for supported protocols
auto ev = new ClientEvent::AppCustomControlMessage("internal:supported_protocols", string::join(conf().app_control_config.supported_protocols, ":"));
auto ev = new ClientEvent::AppCustomControlMessage("internal:supported_protocols", string::join(proto_context.conf().app_control_config.supported_protocols, ":"));
cli_events->add_event(std::move(ev));
}
@ -1050,27 +1043,27 @@ class Session : ProtoContext,
void client_auth(Buffer &buf) override
{
// we never send creds to a relay server
if (creds && !Base::conf().relay_mode)
if (creds && !proto_context.conf().relay_mode)
{
OPENVPN_LOG("Creds: " << creds->auth_info());
Base::write_auth_string(creds->get_username(), buf);
proto_context.write_auth_string(creds->get_username(), buf);
#ifdef OPENVPN_DISABLE_AUTH_TOKEN // debugging only
if (creds->session_id_defined())
{
OPENVPN_LOG("NOTE: not sending auth-token");
Base::write_empty_string(buf);
ProtoContext::write_empty_string(buf);
}
else
#endif
{
Base::write_auth_string(creds->get_password(), buf);
proto_context.write_auth_string(creds->get_password(), buf);
}
}
else
{
OPENVPN_LOG("Creds: None");
Base::write_empty_string(buf); // username
Base::write_empty_string(buf); // password
write_empty_string(buf); // username
write_empty_string(buf); // password
}
}
@ -1081,7 +1074,7 @@ class Session : ProtoContext,
{
if (!e && !halt && !received_options.partial())
{
Base::update_now();
proto_context.update_now();
if (!sent_push_request)
{
ClientEvent::Base::Ptr ev = new ClientEvent::GetConfig();
@ -1089,8 +1082,8 @@ class Session : ProtoContext,
sent_push_request = true;
}
OPENVPN_LOG("Sending PUSH_REQUEST to server...");
Base::write_control_string(std::string("PUSH_REQUEST"));
Base::flush(true);
proto_context.write_control_string(std::string("PUSH_REQUEST"));
proto_context.flush(true);
set_housekeeping_timer();
{
@ -1134,7 +1127,7 @@ class Session : ProtoContext,
// react to any tls warning triggered during the tls-handshake
virtual void check_tls_warnings()
{
uint32_t tls_warnings = get_tls_warnings();
uint32_t tls_warnings = proto_context.get_tls_warnings();
if (tls_warnings & SSLAPI::TLS_WARN_SIG_MD5)
{
@ -1151,13 +1144,13 @@ class Session : ProtoContext,
void check_proto_warnings()
{
if (uses_bs64_cipher())
if (proto_context.uses_bs64_cipher())
{
ClientEvent::Base::Ptr ev = new ClientEvent::Warn("Proto: Using a 64-bit block cipher that is vulnerable to the SWEET32 attack. Please inform your admin to upgrade to a stronger algorithm. Support for 64-bit block cipher will be dropped in the future.");
cli_events->add_event(std::move(ev));
}
CompressContext::Type comp_type = Base::conf().comp_ctx.type();
CompressContext::Type comp_type = proto_context.conf().comp_ctx.type();
// abort connection if compression is pushed and its support is unannounced
if (comp_type != CompressContext::COMP_STUBv2
@ -1203,15 +1196,15 @@ class Session : ProtoContext,
if (!e && !halt)
{
// update current time
Base::update_now();
proto_context.update_now();
housekeeping_schedule.reset();
Base::housekeeping();
if (Base::invalidated())
proto_context.housekeeping();
if (proto_context.invalidated())
{
if (notify_callback)
{
OPENVPN_LOG("Session invalidated: " << Error::name(Base::invalidation_reason()));
OPENVPN_LOG("Session invalidated: " << Error::name(proto_context.invalidation_reason()));
stop(true);
}
else
@ -1231,12 +1224,12 @@ class Session : ProtoContext,
if (halt)
return;
Time next = Base::next_housekeeping();
Time next = proto_context.next_housekeeping();
if (!housekeeping_schedule.similar(next))
{
if (!next.is_infinite())
{
next.max(now());
next.max(proto_context.now());
housekeeping_schedule.reset(next);
housekeeping_timer.expires_at(next);
housekeeping_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error)
@ -1377,7 +1370,7 @@ class Session : ProtoContext,
void schedule_info_hold_callback()
{
Base::update_now();
proto_context.update_now();
info_hold_timer.expires_after(Time::Duration::seconds(1));
info_hold_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error)
{
@ -1391,7 +1384,7 @@ class Session : ProtoContext,
{
if (!e && !halt)
{
Base::update_now();
proto_context.update_now();
if (info_hold)
{
for (auto &ev : *info_hold)
@ -1420,6 +1413,8 @@ class Session : ProtoContext,
}
#endif
ProtoContext proto_context;
openvpn_io::io_context &io_context;
TransportClientFactory::Ptr transport_factory;

View File

@ -109,16 +109,13 @@ class ServerProto
};
// This is the main server-side client instance object
class Session : ProtoContext, // OpenVPN protocol implementation
public TransportLink, // Transport layer
public TunLink, // Tun/routing layer
public ManLink // Management layer
class Session : ProtoContextCallbackInterface, // Callback interface from protocol implementation
public TransportLink, // Transport layer
public TunLink, // Tun/routing layer
public ManLink // Management layer
{
friend class Factory; // calls constructor
using ProtoContext::now;
using ProtoContext::stat;
public:
typedef RCPtr<Session> Ptr;
@ -142,11 +139,11 @@ class ServerProto
peer_addr = addr;
// init OpenVPN protocol handshake
ProtoContext::update_now();
ProtoContext::reset(cookie_psid);
ProtoContext::set_local_peer_id(local_peer_id);
ProtoContext::start(cookie_psid);
ProtoContext::flush(true);
proto_context.update_now();
proto_context.reset(cookie_psid);
proto_context.set_local_peer_id(local_peer_id);
proto_context.start(cookie_psid);
proto_context.flush(true);
// coarse wakeup range
housekeeping_schedule.init(Time::Duration::binary_ms(512), Time::Duration::binary_ms(1024));
@ -182,8 +179,8 @@ class ServerProto
ManLink::send->stats_notify(TransportLink::send->stats_poll(), true);
}
ProtoContext::pre_destroy();
ProtoContext::reset_dc_factory();
proto_context.pre_destroy();
proto_context.reset_dc_factory();
if (TransportLink::send)
{
TransportLink::send->stop();
@ -206,23 +203,23 @@ class ServerProto
virtual bool transport_recv(BufferAllocated &buf) override
{
bool ret = false;
if (!ProtoContext::primary_defined())
if (!proto_context.primary_defined())
return false;
try
{
OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport RECV[" << buf.size() << "] " << client_endpoint_render() << ' ' << ProtoContext::dump_packet(buf));
OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport RECV[" << buf.size() << "] " << client_endpoint_render() << ' ' << proto_context.dump_packet(buf));
// update current time
ProtoContext::update_now();
proto_context.update_now();
// get packet type
ProtoContext::PacketType pt = ProtoContext::packet_type(buf);
ProtoContext::PacketType pt = proto_context.packet_type(buf);
// process packet
if (pt.is_data())
{
// data packet
ret = ProtoContext::data_decrypt(pt, buf);
ret = proto_context.data_decrypt(pt, buf);
if (buf.size())
{
#ifdef OPENVPN_PACKET_LOG
@ -237,15 +234,15 @@ class ServerProto
}
// do a lightweight flush
ProtoContext::flush(false);
proto_context.flush(false);
}
else if (pt.is_control())
{
// control packet
ret = ProtoContext::control_net_recv(pt, std::move(buf));
ret = proto_context.control_net_recv(pt, std::move(buf));
// do a full flush
ProtoContext::flush(true);
proto_context.flush(true);
}
// schedule housekeeping wakeup
@ -269,7 +266,7 @@ class ServerProto
// Return true if keepalive parameter(s) are enabled.
virtual bool is_keepalive_enabled() const override
{
return ProtoContext::is_keepalive_enabled();
return proto_context.is_keepalive_enabled();
}
// Disable keepalive for rest of session, but fetch
@ -278,7 +275,7 @@ class ServerProto
virtual void disable_keepalive(unsigned int &keepalive_ping,
unsigned int &keepalive_timeout) override
{
ProtoContext::disable_keepalive(keepalive_ping, keepalive_timeout);
proto_context.disable_keepalive(keepalive_ping, keepalive_timeout);
if (ManLink::send)
ManLink::send->keepalive_override(keepalive_ping, keepalive_timeout);
}
@ -286,7 +283,7 @@ class ServerProto
// override the data channel factory
virtual void override_dc_factory(const CryptoDCFactory::Ptr &dc_factory) override
{
ProtoContext::dc_settings().set_factory(dc_factory);
proto_context.dc_settings().set_factory(dc_factory);
}
virtual ~Session()
@ -301,7 +298,7 @@ class ServerProto
const Factory &factory,
ManClientInstance::Factory::Ptr man_factory_arg,
TunClientInstance::Factory::Ptr tun_factory_arg)
: ProtoContext(factory.clone_proto_config(), factory.stats),
: proto_context(this, factory.clone_proto_config(), factory.stats),
housekeeping_timer(io_context_arg),
disconnect_at(Time::infinite()),
stats(factory.stats),
@ -318,11 +315,11 @@ class ServerProto
// proto base class calls here for control channel network sends
virtual void control_net_send(const Buffer &net_buf) override
{
OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport SEND[" << net_buf.size() << "] " << client_endpoint_render() << ' ' << ProtoContext::dump_packet(net_buf));
OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport SEND[" << net_buf.size() << "] " << client_endpoint_render() << ' ' << proto_context.dump_packet(net_buf));
if (TransportLink::send)
{
if (TransportLink::send->transport_send_const(net_buf))
ProtoContext::update_last_sent();
proto_context.update_last_sent();
}
}
@ -353,7 +350,7 @@ class ServerProto
if (msg == "PUSH_REQUEST")
{
if (get_management())
ManLink::send->push_request(ProtoContext::conf_ptr());
ManLink::send->push_request(proto_context.conf_ptr());
else
auth_failed("no management provider", "");
}
@ -368,6 +365,14 @@ class ServerProto
}
}
void active(bool primary) override
{
/* Currently the server does not do anything special when the connection
* is ready (control channel fully established). We probably should trigger
* sending a PUSH_REPLY here, when the client requested it via
* IV_PROTO_REQUEST_PUSH instead waiting for an explicit PUSH_REQUEST */
}
virtual void auth_failed(const std::string &reason,
const std::string &client_reason) override
{
@ -379,7 +384,7 @@ class ServerProto
if (halt || disconnect_type == DT_HALT_RESTART)
return;
ProtoContext::update_now();
proto_context.update_now();
if (TunLink::send && (disconnect_type < DT_RELAY_TRANSITION))
{
@ -388,13 +393,13 @@ class ServerProto
disconnect_in(Time::Duration::seconds(10)); // not a real disconnect, just complete transition to relay
}
if (ProtoContext::primary_defined())
if (proto_context.primary_defined())
{
BufferPtr buf(new BufferAllocated(64, 0));
buf_append_string(*buf, "RELAY");
buf->null_terminate();
ProtoContext::control_send(std::move(buf));
ProtoContext::flush(true);
proto_context.control_send(std::move(buf));
proto_context.flush(true);
}
set_housekeeping_timer();
@ -402,7 +407,7 @@ class ServerProto
virtual void push_reply(std::vector<BufferPtr> &&push_msgs) override
{
if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !ProtoContext::primary_defined())
if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !proto_context.primary_defined())
return;
if (disconnect_type == DT_AUTH_PENDING)
@ -411,17 +416,17 @@ class ServerProto
cancel_disconnect();
}
ProtoContext::update_now();
proto_context.update_now();
if (get_tun())
{
ProtoContext::init_data_channel();
proto_context.init_data_channel();
for (auto &msg : push_msgs)
{
msg->null_terminate();
ProtoContext::control_send(std::move(msg));
proto_context.control_send(std::move(msg));
}
ProtoContext::flush(true);
proto_context.flush(true);
set_housekeeping_timer();
}
else
@ -445,7 +450,7 @@ class ServerProto
if (halt || disconnect_type == DT_HALT_RESTART)
return;
ProtoContext::update_now();
proto_context.update_now();
BufferPtr buf(new BufferAllocated(128, BufferAllocated::GROW));
BufferStreamOut os(*buf);
@ -520,11 +525,11 @@ class ServerProto
OPENVPN_LOG(instance_name() << " : Disconnect: " << ts << ' ' << reason);
if (ProtoContext::primary_defined())
if (proto_context.primary_defined())
{
buf->null_terminate();
ProtoContext::control_send(std::move(buf));
ProtoContext::flush(true);
proto_context.control_send(std::move(buf));
proto_context.flush(true);
}
set_housekeeping_timer();
@ -534,7 +539,7 @@ class ServerProto
{
if (halt || disconnect_type == DT_HALT_RESTART)
return;
ProtoContext::update_now();
proto_context.update_now();
disconnect_in(Time::Duration::seconds(seconds));
set_housekeeping_timer();
}
@ -543,7 +548,7 @@ class ServerProto
{
if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !seconds)
return;
ProtoContext::update_now();
proto_context.update_now();
disconnect_type = DT_AUTH_PENDING;
disconnect_in(Time::Duration::seconds(seconds));
set_housekeeping_timer();
@ -551,41 +556,41 @@ class ServerProto
virtual void post_cc_msg(BufferPtr &&msg) override
{
if (halt || !ProtoContext::primary_defined())
if (halt || !proto_context.primary_defined())
return;
ProtoContext::update_now();
proto_context.update_now();
msg->null_terminate();
ProtoContext::control_send(std::move(msg));
ProtoContext::flush(true);
proto_context.control_send(std::move(msg));
proto_context.flush(true);
set_housekeeping_timer();
}
virtual void stats_notify(const PeerStats &ps, const bool final) override
void stats_notify(const PeerStats &ps, const bool final) override
{
if (ManLink::send)
ManLink::send->stats_notify(ps, final);
}
virtual void float_notify(const PeerAddr::Ptr &addr) override
void float_notify(const PeerAddr::Ptr &addr) override
{
if (ManLink::send)
ManLink::send->float_notify(addr);
}
virtual void ipma_notify(const struct ovpn_tun_head_ipma &ipma) override
void ipma_notify(const struct ovpn_tun_head_ipma &ipma) override
{
if (ManLink::send)
ManLink::send->ipma_notify(ipma);
}
virtual void data_limit_notify(const int key_id,
const DataLimit::Mode cdl_mode,
const DataLimit::State cdl_status) override
void data_limit_notify(const int key_id,
const DataLimit::Mode cdl_mode,
const DataLimit::State cdl_status) override
{
ProtoContext::update_now();
ProtoContext::data_limit_notify(key_id, cdl_mode, cdl_status);
ProtoContext::flush(true);
proto_context.update_now();
proto_context.data_limit_notify(key_id, cdl_mode, cdl_status);
proto_context.flush(true);
set_housekeeping_timer();
}
@ -613,7 +618,7 @@ class ServerProto
// and set_housekeeping_timer() called after this method
void disconnect_in(const Time::Duration &dur)
{
disconnect_at = now() + dur;
disconnect_at = proto_context.now() + dur;
}
void cancel_disconnect()
@ -628,13 +633,13 @@ class ServerProto
if (!e && !halt)
{
// update current time
ProtoContext::update_now();
proto_context.update_now();
housekeeping_schedule.reset();
ProtoContext::housekeeping();
if (ProtoContext::invalidated())
invalidation_error(ProtoContext::invalidation_reason());
else if (now() >= disconnect_at)
proto_context.housekeeping();
if (proto_context.invalidated())
invalidation_error(proto_context.invalidation_reason());
else if (proto_context.now() >= disconnect_at)
{
switch (disconnect_type)
{
@ -642,7 +647,7 @@ class ServerProto
error("disconnect triggered");
break;
case DT_RELAY_TRANSITION:
ProtoContext::pre_destroy();
proto_context.pre_destroy();
break;
case DT_AUTH_PENDING:
auth_failed("Auth Pending Timeout", "Auth Pending Timeout");
@ -664,13 +669,13 @@ class ServerProto
void set_housekeeping_timer()
{
Time next = ProtoContext::next_housekeeping();
Time next = proto_context.next_housekeeping();
next.min(disconnect_at);
if (!housekeeping_schedule.similar(next))
{
if (!next.is_infinite())
{
next.max(now());
next.max(proto_context.now());
housekeeping_schedule.reset(next);
housekeeping_timer.expires_at(next);
housekeeping_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error)
@ -738,6 +743,8 @@ class ServerProto
DT_RELAY_TRANSITION,
DT_HALT_RESTART,
};
ProtoContext proto_context;
int disconnect_type = DT_NONE;
bool preserve_session_id = true;

View File

@ -168,6 +168,53 @@ enum
} // namespace
} // namespace proto_context_private
class ProtoContextCallbackInterface
{
public:
/**
* Sends out bytes to the network.
*/
virtual void control_net_send(const Buffer &net_buf) = 0;
/*
* Receive as packet from the network
* \note app may take ownership of app_bp via std::move
*/
virtual void control_recv(BufferPtr &&app_bp) = 0;
/** Called on client to request username/password credentials.
* Should be overridden by derived class if credentials are required.
* username and password should be written into buf with write_auth_string().
*/
virtual void client_auth(Buffer &buf)
{
write_empty_string(buf); // username
write_empty_string(buf); // password
}
/** Called on server with credentials and peer info provided by client.
*Should be overriden by derived class if credentials are required. */
virtual void server_auth(const std::string &username,
const SafeString &password,
const std::string &peer_info,
const AuthCert::Ptr &auth_cert)
{
}
/**
* Writes an empty user or password string for the key-method 2 packet in the OpenVPN protocol
* @param buf buffer to write to
*/
static void write_empty_string(Buffer &buf)
{
uint8_t empty[]{0x00, 0x00}; // empty length field without content
buf.write(&empty, 2);
}
//! Called when KeyContext transitions to ACTIVE state
virtual void active(bool primary) = 0;
};
class ProtoContext
{
protected:
@ -1388,9 +1435,8 @@ class ProtoContext
return out.str();
}
protected:
// used for reading/writing authentication strings (username, password, etc.)
// used for reading/writing authentication strings (username, password, etc.) from buffer using the
// 2 byte prefix for length
static void write_uint16_length(const size_t size, Buffer &buf)
{
if (size > 0xFFFF)
@ -1446,6 +1492,11 @@ class ProtoContext
buf.null_terminate();
}
static void write_empty_string(Buffer &buf)
{
write_uint16_length(0, buf);
}
template <typename S>
static S read_control_string(const Buffer &buf)
{
@ -1469,17 +1520,7 @@ class ProtoContext
control_send(std::move(bp));
}
static unsigned char *skip_string(Buffer &buf)
{
const size_t len = read_uint16_length(buf);
return buf.read_alloc(len);
}
static void write_empty_string(Buffer &buf)
{
write_uint16_length(0, buf);
}
protected:
// Packet structure for managing network packets, passed as a template
// parameter to ProtoStackBase
class Packet
@ -2785,7 +2826,7 @@ class ProtoContext
const std::string username = read_auth_string<std::string>(*buf);
const SafeString password = read_auth_string<SafeString>(*buf);
const std::string peer_info = read_auth_string<std::string>(*buf);
proto.server_auth(username, password, peer_info, Base::auth_cert());
proto.proto_callback->server_auth(username, password, peer_info, Base::auth_cert());
}
}
@ -3641,9 +3682,11 @@ class ProtoContext
OPENVPN_SIMPLE_EXCEPTION(select_key_context_error);
ProtoContext(const ProtoConfig::Ptr &config_arg, // configuration
ProtoContext(ProtoContextCallbackInterface *cb_arg,
const ProtoConfig::Ptr &config_arg, // configuration
const SessionStats::Ptr &stats_arg) // error stats
: config(config_arg),
: proto_callback(cb_arg),
config(config_arg),
stats(stats_arg),
mode_(config_arg->ssl_factory->mode()),
n_key_ids(0),
@ -4263,8 +4306,13 @@ class ProtoContext
return *stats;
}
protected:
// debugging
bool is_state_client_wait_reset_ack() const
{
return primary_state() == C_WAIT_RESET_ACK;
}
protected:
int primary_state() const
{
if (primary)
@ -4291,32 +4339,11 @@ class ProtoContext
secondary.reset();
}
virtual void control_net_send(const Buffer &net_buf) = 0;
// app may take ownership of app_bp via std::move
virtual void control_recv(BufferPtr &&app_bp) = 0;
// Called on client to request username/password credentials.
// Should be overriden by derived class if credentials are required.
// username and password should be written into buf with write_auth_string().
virtual void client_auth(Buffer &buf)
{
write_empty_string(buf); // username
write_empty_string(buf); // password
}
// Called on server with credentials and peer info provided by client.
// Should be overriden by derived class if credentials are required.
virtual void server_auth(const std::string &username,
const SafeString &password,
const std::string &peer_info,
const AuthCert::Ptr &auth_cert)
{
}
// Called when KeyContext transitions to ACTIVE state
virtual void active(bool primary)
// delegated to the callback/parent
void client_auth(Buffer &buf)
{
proto_callback->client_auth(buf);
}
void update_last_received()
@ -4326,12 +4353,12 @@ class ProtoContext
void net_send(const unsigned int key_id, const Packet &net_pkt)
{
control_net_send(net_pkt.buffer());
proto_callback->control_net_send(net_pkt.buffer());
}
void app_recv(const unsigned int key_id, BufferPtr &&to_app_buf)
{
control_recv(std::move(to_app_buf));
proto_callback->control_recv(std::move(to_app_buf));
}
// we're getting a request from peer to renegotiate.
@ -4471,7 +4498,7 @@ class ProtoContext
case KeyContext::KEV_ACTIVE:
OPENVPN_LOG_PROTO_VERBOSE(debug_prefix() << " SESSION_ACTIVE");
primary->rekey(CryptoDCInstance::ACTIVATE_PRIMARY);
active(true);
proto_callback->active(true);
break;
case KeyContext::KEV_RENEGOTIATE:
case KeyContext::KEV_RENEGOTIATE_FORCE:
@ -4511,7 +4538,7 @@ class ProtoContext
secondary->rekey(CryptoDCInstance::NEW_SECONDARY);
if (primary)
primary->prepare_expire();
active(false);
proto_callback->active(false);
break;
case KeyContext::KEV_BECOME_PRIMARY:
if (!secondary->invalidated())
@ -4590,6 +4617,13 @@ class ProtoContext
// BEGIN ProtoContext data members
/** the class that uses this class needs to be called back on a few things. Typically a class
* that uses this class as field for composition. This parent/callback class needs to ensure that
* it lives longer than this class, e.g. by having this class as field as this class blindly
* assumes that this pointer is always valid for its lifetime
*/
ProtoContextCallbackInterface *proto_callback;
ProtoConfig::Ptr config;
SessionStats::Ptr stats;

View File

@ -353,24 +353,20 @@ class DroughtMeasure
};
// test the OpenVPN protocol implementation in ProtoContext
class TestProto : public ProtoContext
class TestProto : public ProtoContextCallbackInterface
{
typedef ProtoContext Base;
/* Callback methods that are not used */
void active(bool primary) override
{
}
using Base::is_server;
using Base::mode;
using Base::now;
public:
using Base::flush;
typedef Base::PacketType PacketType;
OPENVPN_EXCEPTION(session_invalidated);
TestProto(const Base::ProtoConfig::Ptr &config,
TestProto(const ProtoContext::ProtoConfig::Ptr &config,
const SessionStats::Ptr &stats)
: Base(config, stats),
: proto_context(this, config, stats),
control_drought("control", config->now),
data_drought("data", config->now),
frame(config->frame)
@ -382,27 +378,26 @@ class TestProto : public ProtoContext
void reset()
{
net_out.clear();
Base::reset();
Base::conf().mss_parms.mssfix = MSSParms::MSSFIX_DEFAULT;
proto_context.reset();
proto_context.conf().mss_parms.mssfix = MSSParms::MSSFIX_DEFAULT;
}
void initial_app_send(const char *msg)
{
Base::start();
proto_context.start();
const size_t msglen = std::strlen(msg) + 1;
BufferAllocated app_buf((unsigned char *)msg, msglen, 0);
copy_progress(app_buf);
control_send(std::move(app_buf));
flush(true);
proto_context.flush(true);
}
void app_send_templ_init(const char *msg)
{
Base::start();
proto_context.start();
const size_t msglen = std::strlen(msg) + 1;
templ.reset(new BufferAllocated((unsigned char *)msg, msglen, 0));
flush(true);
proto_context.flush(true);
}
void app_send_templ()
@ -421,9 +416,9 @@ class TestProto : public ProtoContext
bool do_housekeeping()
{
if (now() >= Base::next_housekeeping())
if (proto_context.now() >= proto_context.next_housekeeping())
{
Base::housekeeping();
proto_context.housekeeping();
return true;
}
else
@ -433,13 +428,13 @@ class TestProto : public ProtoContext
void control_send(BufferPtr &&app_bp)
{
app_bytes_ += app_bp->size();
Base::control_send(std::move(app_bp));
proto_context.control_send(std::move(app_bp));
}
void control_send(BufferAllocated &&app_buf)
{
app_bytes_ += app_buf.size();
Base::control_send(std::move(app_buf));
proto_context.control_send(std::move(app_buf));
}
BufferPtr data_encrypt_string(const char *str)
@ -453,12 +448,12 @@ class TestProto : public ProtoContext
void data_encrypt(BufferAllocated &in_out)
{
Base::data_encrypt(in_out);
proto_context.data_encrypt(in_out);
}
void data_decrypt(const PacketType &type, BufferAllocated &in_out)
void data_decrypt(const ProtoContext::PacketType &type, BufferAllocated &in_out)
{
Base::data_decrypt(type, in_out);
proto_context.data_decrypt(type, in_out);
if (in_out.size())
{
data_bytes_ += in_out.size();
@ -500,13 +495,8 @@ class TestProto : public ProtoContext
void check_invalidated()
{
if (Base::invalidated())
throw session_invalidated(Error::name(Base::invalidation_reason()));
}
bool is_state_client_wait_reset_ack() const
{
return primary_state() == C_WAIT_RESET_ACK;
if (proto_context.invalidated())
throw session_invalidated(Error::name(proto_context.invalidation_reason()));
}
void disable_xmit()
@ -514,13 +504,15 @@ class TestProto : public ProtoContext
disable_xmit_ = true;
}
ProtoContext proto_context;
std::deque<BufferPtr> net_out;
DroughtMeasure control_drought;
DroughtMeasure data_drought;
private:
virtual void control_net_send(const Buffer &net_buf)
void control_net_send(const Buffer &net_buf) override
{
if (disable_xmit_)
return;
@ -528,7 +520,7 @@ class TestProto : public ProtoContext
net_out.push_back(BufferPtr(new BufferAllocated(net_buf, 0)));
}
virtual void control_recv(BufferPtr &&app_bp)
void control_recv(BufferPtr &&app_bp) override
{
BufferPtr work;
work.swap(app_bp);
@ -559,7 +551,7 @@ class TestProto : public ProtoContext
void modmsg(BufferPtr &buf)
{
char *msg = (char *)buf->data();
if (is_server())
if (proto_context.is_server())
{
msg[8] = 'S';
msg[11] = 'C';
@ -602,9 +594,9 @@ class TestProtoClient : public TestProto
typedef TestProto Base;
public:
TestProtoClient(const Base::ProtoConfig::Ptr &config,
TestProtoClient(const ProtoContext::ProtoConfig::Ptr &config,
const SessionStats::Ptr &stats)
: Base(config, stats)
: TestProto(config, stats)
{
}
@ -613,21 +605,26 @@ class TestProtoClient : public TestProto
{
const std::string username("foo");
const std::string password("bar");
Base::write_auth_string(username, buf);
Base::write_auth_string(password, buf);
ProtoContext::write_auth_string(username, buf);
ProtoContext::write_auth_string(password, buf);
}
};
class TestProtoServer : public TestProto
{
typedef TestProto Base;
public:
void start()
{
proto_context.start();
}
OPENVPN_SIMPLE_EXCEPTION(auth_failed);
TestProtoServer(const Base::ProtoConfig::Ptr &config,
TestProtoServer(const ProtoContext::ProtoConfig::Ptr &config,
const SessionStats::Ptr &stats)
: Base(config, stats)
: TestProto(config, stats)
{
}
@ -687,7 +684,7 @@ class NoisyWire
a.app_send_templ();
// queue a data channel packet
if (a.data_channel_ready())
if (a.proto_context.data_channel_ready())
{
BufferPtr bp = a.data_encrypt_string("Waiting for godot A... Waiting for godot B... Waiting for godot C... Waiting for godot D... Waiting for godot E... Waiting for godot F... Waiting for godot G... Waiting for godot H... Waiting for godot I... Waiting for godot J...");
wire.push_back(bp);
@ -710,14 +707,14 @@ class NoisyWire
BufferPtr bp = recv();
if (!bp)
break;
typename T2::PacketType pt = b.packet_type(*bp);
typename ProtoContext::PacketType pt = b.proto_context.packet_type(*bp);
if (pt.is_control())
{
#ifdef VERBOSE
if (!b.control_net_validate(pt, *bp)) // not strictly necessary since control_net_recv will also validate
std::cout << now->raw() << " " << title << " CONTROL PACKET VALIDATION FAILED" << std::endl;
#endif
b.control_net_recv(pt, std::move(bp));
b.proto_context.control_net_recv(pt, std::move(bp));
}
else if (pt.is_data())
{
@ -744,11 +741,11 @@ class NoisyWire
#ifdef VERBOSE
std::cout << now->raw() << " " << title << " KEY_STATE_ERROR" << std::endl;
#endif
b.stat().error(Error::KEY_STATE_ERROR);
b.proto_context.stat().error(Error::KEY_STATE_ERROR);
}
#ifdef SIMULATE_UDP_AMPLIFY_ATTACK
if (b.is_state_client_wait_reset_ack())
if (b.proto_context.is_state_client_wait_reset_ack())
{
b.disable_xmit();
#ifdef VERBOSE
@ -757,7 +754,7 @@ class NoisyWire
}
#endif
}
b.flush(true);
b.proto_context.flush(true);
}
private:
@ -1148,8 +1145,8 @@ int test(const int thread_num)
<< " CTRL=" << cli_proto.n_control_recv() << '/' << cli_proto.n_control_send() << '/' << serv_proto.n_control_recv() << '/' << serv_proto.n_control_send()
#endif
<< " D=" << cli_proto.control_drought().raw() << '/' << cli_proto.data_drought().raw() << '/' << serv_proto.control_drought().raw() << '/' << serv_proto.data_drought().raw()
<< " N=" << cli_proto.negotiations() << '/' << serv_proto.negotiations()
<< " SH=" << cli_proto.slowest_handshake().raw() << '/' << serv_proto.slowest_handshake().raw()
<< " N=" << cli_proto.proto_context.negotiations() << '/' << serv_proto.proto_context.negotiations()
<< " SH=" << cli_proto.proto_context.slowest_handshake().raw() << '/' << serv_proto.proto_context.slowest_handshake().raw()
<< " HE=" << cli_stats->get_error_count(Error::HANDSHAKE_TIMEOUT) << '/' << serv_stats->get_error_count(Error::HANDSHAKE_TIMEOUT)
<< std::endl;