From e14c3f0441163c695ef8524032e6fc695edc863c Mon Sep 17 00:00:00 2001 From: Arne Schwabe Date: Thu, 22 Feb 2024 07:26:46 +0100 Subject: [PATCH] Refactor ClientProto::Session to use ProtoContext as field insatead of Base Currently the protocontext is used as kind of composition but not really and makes following the code harder, since this inheritance not only serves for composition but also as callbacks through virtual method inheritance. Making ProtoContext a normal field and definining a callback interface makes the class relationship easier to understand. Signed-off-by: Arne Schwabe --- openvpn/client/cliopt.hpp | 20 ++--- openvpn/client/cliproto.hpp | 149 ++++++++++++++++------------------ openvpn/server/servproto.hpp | 139 ++++++++++++++++--------------- openvpn/ssl/proto.hpp | 126 +++++++++++++++++----------- test/unittests/test_proto.cpp | 97 +++++++++++----------- 5 files changed, 282 insertions(+), 249 deletions(-) diff --git a/openvpn/client/cliopt.hpp b/openvpn/client/cliopt.hpp index 4648de92..fe87a2ef 100644 --- a/openvpn/client/cliopt.hpp +++ b/openvpn/client/cliopt.hpp @@ -1164,7 +1164,7 @@ class ClientOptions : public RC // Copy ProtoConfig so that modifications due to server push will // not persist across client instantiations. - cli_config->proto_context_config.reset(new Client::ProtoConfig(proto_config_cached(relay_mode))); + cli_config->proto_context_config.reset(new ProtoContext::ProtoConfig(proto_config_cached(relay_mode))); cli_config->proto_context_options = proto_context_options; cli_config->push_base = push_base; @@ -1274,7 +1274,7 @@ class ClientOptions : public RC } private: - Client::ProtoConfig &proto_config_cached(const bool relay_mode) + ProtoContext::ProtoConfig &proto_config_cached(const bool relay_mode) { if (relay_mode && cp_relay) return *cp_relay; @@ -1282,14 +1282,14 @@ class ClientOptions : public RC return *cp_main; } - Client::ProtoConfig::Ptr proto_config(const OptionList &opt, - const Config &config, - const ParseClientConfig &pcc, - const bool relay_mode) + ProtoContext::ProtoConfig::Ptr proto_config(const OptionList &opt, + const Config &config, + const ParseClientConfig &pcc, + const bool relay_mode) { // relay mode is null unless one of the below directives is defined if (relay_mode && !opt.exists("relay-mode")) - return Client::ProtoConfig::Ptr(); + return ProtoContext::ProtoConfig::Ptr(); // load flags unsigned int lflags = SSLConfigAPI::LF_PARSE_MODE; @@ -1314,7 +1314,7 @@ class ClientOptions : public RC cc->set_tls_ciphersuite_list(config.clientconf.tlsCiphersuitesList); // client ProtoContext config - Client::ProtoConfig::Ptr cp(new Client::ProtoConfig()); + ProtoContext::ProtoConfig::Ptr cp(new ProtoContext::ProtoConfig()); cp->ssl_factory = cc->new_factory(); cp->relay_mode = relay_mode; cp->dc.set_factory(new CryptoDCSelect(cp->ssl_factory->libctx(), frame, cli_stats, rng)); @@ -1454,8 +1454,8 @@ class ClientOptions : public RC RandomAPI::Ptr prng; Frame::Ptr frame; Layer layer; - Client::ProtoConfig::Ptr cp_main; - Client::ProtoConfig::Ptr cp_relay; + ProtoContext::ProtoConfig::Ptr cp_main; + ProtoContext::ProtoConfig::Ptr cp_relay; RemoteList::Ptr remote_list; bool server_addr_float; TransportClientFactory::Ptr transport_factory; diff --git a/openvpn/client/cliproto.hpp b/openvpn/client/cliproto.hpp index ccace8d5..2c99a6e8 100644 --- a/openvpn/client/cliproto.hpp +++ b/openvpn/client/cliproto.hpp @@ -92,20 +92,13 @@ struct NotifyCallback } }; -class Session : ProtoContext, +class Session : ProtoContextCallbackInterface, TransportClientParent, TunClientParent, public RC { - typedef ProtoContext Base; - typedef Base::PacketType PacketType; - - using Base::now; - using Base::stat; - public: typedef RCPtr Ptr; - typedef Base::ProtoConfig ProtoConfig; OPENVPN_EXCEPTION(client_exception); OPENVPN_EXCEPTION(client_halt_restart); @@ -133,7 +126,7 @@ class Session : ProtoContext, { } - ProtoConfig::Ptr proto_context_config; + ProtoContext::ProtoConfig::Ptr proto_context_config; ProtoContextCompressionOptions::Ptr proto_context_options; PushOptionsBase::Ptr push_base; TransportClientFactory::Ptr transport_factory; @@ -152,7 +145,7 @@ class Session : ProtoContext, Session(openvpn_io::io_context &io_context_arg, const Config &config, NotifyCallback *notify_callback_arg) - : Base(config.proto_context_config, config.cli_stats), + : proto_context(this, config.proto_context_config, config.cli_stats), io_context(io_context_arg), transport_factory(config.transport_factory), tun_factory(config.tun_factory), @@ -178,9 +171,9 @@ class Session : ProtoContext, if (!packet_log) OPENVPN_THROW(open_file_error, "cannot open packet log for output: " << OPENVPN_PACKET_LOG); #endif - Base::update_now(); - Base::reset(); - // Base::enable_strict_openvpn_2x(); + proto_context.update_now(); + proto_context.reset(); + // proto_context.enable_strict_openvpn_2x(); info_hold.reset(new std::vector()); } @@ -194,7 +187,7 @@ class Session : ProtoContext, { if (!halt) { - Base::update_now(); + proto_context.update_now(); // coarse wakeup range housekeeping_schedule.init(Time::Duration::binary_ms(512), Time::Duration::binary_ms(1024)); @@ -224,7 +217,7 @@ class Session : ProtoContext, void send_explicit_exit_notify() { if (!halt) - Base::send_explicit_exit_notify(); + proto_context.send_explicit_exit_notify(); } void tun_set_disconnect() @@ -235,22 +228,22 @@ class Session : ProtoContext, void post_cc_msg(const std::string &msg) { - Base::update_now(); - Base::write_control_string(msg); - Base::flush(true); + proto_context.update_now(); + proto_context.write_control_string(msg); + proto_context.flush(true); set_housekeeping_timer(); } void post_app_control_message(const std::string proto, const std::string message) { - if (!conf().app_control_config.supports_protocol(proto)) + if (!proto_context.conf().app_control_config.supports_protocol(proto)) { ClientEvent::Base::Ptr ev = new ClientEvent::UnsupportedFeature{"missing acc protocol support", "server has not announced support of this custom app control protocol", false}; cli_events->add_event(std::move(ev)); return; } - for (auto fragment : conf().app_control_config.format_message(proto, message)) + for (auto fragment : proto_context.conf().app_control_config.format_message(proto, message)) post_cc_msg(std::move(fragment)); } @@ -328,13 +321,13 @@ class Session : ProtoContext, { try { - OPENVPN_LOG_CLIPROTO("Transport RECV " << server_endpoint_render() << ' ' << Base::dump_packet(buf)); + OPENVPN_LOG_CLIPROTO("Transport RECV " << server_endpoint_render() << ' ' << proto_context.dump_packet(buf)); // update current time - Base::update_now(); + proto_context.update_now(); // update last packet received - stat().update_last_packet_received(now()); + proto_context.stat().update_last_packet_received(proto_context.now()); // log connecting event (only on first packet received) if (!first_packet_received_) @@ -345,13 +338,13 @@ class Session : ProtoContext, } // get packet type - Base::PacketType pt = Base::packet_type(buf); + ProtoContext::PacketType pt = proto_context.packet_type(buf); // process packet if (pt.is_data()) { // data packet - Base::data_decrypt(pt, buf); + proto_context.data_decrypt(pt, buf); if (buf.size()) { #ifdef OPENVPN_PACKET_LOG @@ -366,15 +359,15 @@ class Session : ProtoContext, } // do a lightweight flush - Base::flush(false); + proto_context.flush(false); } else if (pt.is_control()) { // control packet - Base::control_net_recv(pt, std::move(buf)); + proto_context.control_net_recv(pt, std::move(buf)); // do a full flush - Base::flush(true); + proto_context.flush(true); } else cli_stats->error(Error::KEY_STATE_ERROR); @@ -412,7 +405,7 @@ class Session : ProtoContext, OPENVPN_LOG_CLIPROTO("TUN recv, size=" << buf.size()); // update current time - Base::update_now(); + proto_context.update_now(); // log packet #ifdef OPENVPN_PACKET_LOG @@ -432,7 +425,7 @@ class Session : ProtoContext, // encrypt packet if (buf.size()) { - const ProtoContext::ProtoConfig &c = Base::conf(); + const ProtoContext::ProtoConfig &c = proto_context.conf(); // when calculating mss, we take IPv4 and TCP headers into account // here we need to add it back since we check the whole IP packet size, not just TCP payload constexpr size_t MinTcpHeader = 20; @@ -445,13 +438,13 @@ class Session : ProtoContext, } else { - Base::data_encrypt(buf); + proto_context.data_encrypt(buf); if (buf.size()) { // send packet via transport to destination - OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << Base::dump_packet(buf)); + OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << proto_context.dump_packet(buf)); if (transport->transport_send(buf)) - Base::update_last_sent(); + proto_context.update_last_sent(); else if (halt) return; } @@ -459,7 +452,7 @@ class Session : ProtoContext, } // do a lightweight flush - Base::flush(false); + proto_context.flush(false); // schedule housekeeping wakeup set_housekeeping_timer(); @@ -473,7 +466,7 @@ class Session : ProtoContext, // Return true if keepalive parameter(s) are enabled. bool is_keepalive_enabled() const override { - return Base::is_keepalive_enabled(); + return proto_context.is_keepalive_enabled(); } // Disable keepalive for rest of session, but fetch @@ -481,7 +474,7 @@ class Session : ProtoContext, void disable_keepalive(unsigned int &keepalive_ping, unsigned int &keepalive_timeout) override { - Base::disable_keepalive(keepalive_ping, keepalive_timeout); + proto_context.disable_keepalive(keepalive_ping, keepalive_timeout); } void transport_pre_resolve() override @@ -516,9 +509,9 @@ class Session : ProtoContext, try { OPENVPN_LOG("Connecting to " << server_endpoint_render()); - Base::set_protocol(transport->transport_protocol()); - Base::start(); - Base::flush(true); + proto_context.set_protocol(transport->transport_protocol()); + proto_context.start(); + proto_context.flush(true); set_housekeeping_timer(); } catch (const std::exception &e) @@ -587,7 +580,7 @@ class Session : ProtoContext, OPENVPN_LOG("Session token: [redacted]"); #endif autologin_sessions = true; - conf().set_xmit_creds(true); + proto_context.conf().set_xmit_creds(true); creds->set_replace_password_with_session_id(true); creds->set_session_id(username, sess_id); } @@ -688,9 +681,9 @@ class Session : ProtoContext, // proto base class calls here for control channel network sends void control_net_send(const Buffer &net_buf) override { - OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << Base::dump_packet(net_buf)); + OPENVPN_LOG_CLIPROTO("Transport SEND " << server_endpoint_render() << ' ' << proto_context.dump_packet(net_buf)); if (transport->transport_send_const(net_buf)) - Base::update_last_sent(); + proto_context.update_last_sent(); } void recv_auth_failed(const std::string &msg) @@ -753,7 +746,7 @@ class Session : ProtoContext, { timeout = clamp_to_typerange(std::stoul(timeout_str)); // Cap the timeout to end well before renegotiation starts - timeout = std::min(timeout, static_cast(conf().renegotiate.to_seconds() / 2)); + timeout = std::min(timeout, static_cast(proto_context.conf().renegotiate.to_seconds() / 2)); } catch (const std::logic_error &) { @@ -776,7 +769,7 @@ class Session : ProtoContext, void recv_relay() { - if (conf().relay_mode) + if (proto_context.conf().relay_mode) { fatal_ = Error::RELAY; fatal_reason_ = ""; @@ -818,7 +811,7 @@ class Session : ProtoContext, // proto base class calls here for app-level control-channel messages received void control_recv(BufferPtr &&app_bp) override { - const std::string msg = Unicode::utf8_printable(Base::template read_control_string(*app_bp), + const std::string msg = Unicode::utf8_printable(ProtoContext::template read_control_string(*app_bp), Unicode::UTF8_FILTER | Unicode::UTF8_PASS_FMT); // OPENVPN_LOG("SERVER: " << sanitize_control_message(msg)); @@ -859,13 +852,13 @@ class Session : ProtoContext, void recv_custom_control_message(const std::string msg) { - bool fullmessage = conf().app_control_recv.receive_message(msg); + bool fullmessage = proto_context.conf().app_control_recv.receive_message(msg); if (!fullmessage) return; - auto [proto, app_proto_msg] = conf().app_control_recv.get_message(); + auto [proto, app_proto_msg] = proto_context.conf().app_control_recv.get_message(); - if (conf().app_control_config.supports_protocol(proto)) + if (proto_context.conf().app_control_config.supports_protocol(proto)) { auto ev = new ClientEvent::AppCustomControlMessage(std::move(proto), std::move(app_proto_msg)); cli_events->add_event(std::move(ev)); @@ -897,7 +890,7 @@ class Session : ProtoContext, << render_options_sanitized(received_options, Option::RENDER_PASS_FMT | Option::RENDER_NUMBER | Option::RENDER_BRACKET)); // relay servers are not allowed to establish a tunnel with us - if (Base::conf().relay_mode) + if (proto_context.conf().relay_mode) { tun_error(Error::RELAY_ERROR, "tunnel not permitted to relay server"); return; @@ -917,28 +910,28 @@ class Session : ProtoContext, transport_factory->process_push(received_options); // modify proto config (cipher, auth, key-derivation and compression methods) - Base::process_push(received_options, *proto_context_options); + proto_context.process_push(received_options, *proto_context_options); // initialize tun/routing tun = tun_factory->new_tun_client_obj(io_context, *this, transport.get()); - tun->tun_start(received_options, *transport, Base::dc_settings()); + tun->tun_start(received_options, *transport, proto_context.dc_settings()); // we should be connected at this point if (!connected_) throw tun_exception("not connected"); // Propagate tun-mtu back, it might have been overwritten by a pushed tun-mtu option - conf().tun_mtu = tun->vpn_mtu(); + proto_context.conf().tun_mtu = tun->vpn_mtu(); // initialize data channel after pushed options have been processed - Base::init_data_channel(); + proto_context.init_data_channel(); // we got pushed options and initializated crypto - now we can push mss to dco - tun->adjust_mss(conf().mss_fix); + tun->adjust_mss(proto_context.conf().mss_fix); // Allow ProtoContext to suggest an alignment adjustment // hint for transport layer. - transport->reset_align_adjust(Base::align_adjust_hint()); + transport->reset_align_adjust(proto_context.align_adjust_hint()); // process "inactive" directive process_inactive(received_options); @@ -954,10 +947,10 @@ class Session : ProtoContext, cli_events->add_event(connected_); // send an event for custom app control if present - if (!conf().app_control_config.supported_protocols.empty()) + if (!proto_context.conf().app_control_config.supported_protocols.empty()) { // Signal support for supported protocols - auto ev = new ClientEvent::AppCustomControlMessage("internal:supported_protocols", string::join(conf().app_control_config.supported_protocols, ":")); + auto ev = new ClientEvent::AppCustomControlMessage("internal:supported_protocols", string::join(proto_context.conf().app_control_config.supported_protocols, ":")); cli_events->add_event(std::move(ev)); } @@ -1050,27 +1043,27 @@ class Session : ProtoContext, void client_auth(Buffer &buf) override { // we never send creds to a relay server - if (creds && !Base::conf().relay_mode) + if (creds && !proto_context.conf().relay_mode) { OPENVPN_LOG("Creds: " << creds->auth_info()); - Base::write_auth_string(creds->get_username(), buf); + proto_context.write_auth_string(creds->get_username(), buf); #ifdef OPENVPN_DISABLE_AUTH_TOKEN // debugging only if (creds->session_id_defined()) { OPENVPN_LOG("NOTE: not sending auth-token"); - Base::write_empty_string(buf); + ProtoContext::write_empty_string(buf); } else #endif { - Base::write_auth_string(creds->get_password(), buf); + proto_context.write_auth_string(creds->get_password(), buf); } } else { OPENVPN_LOG("Creds: None"); - Base::write_empty_string(buf); // username - Base::write_empty_string(buf); // password + write_empty_string(buf); // username + write_empty_string(buf); // password } } @@ -1081,7 +1074,7 @@ class Session : ProtoContext, { if (!e && !halt && !received_options.partial()) { - Base::update_now(); + proto_context.update_now(); if (!sent_push_request) { ClientEvent::Base::Ptr ev = new ClientEvent::GetConfig(); @@ -1089,8 +1082,8 @@ class Session : ProtoContext, sent_push_request = true; } OPENVPN_LOG("Sending PUSH_REQUEST to server..."); - Base::write_control_string(std::string("PUSH_REQUEST")); - Base::flush(true); + proto_context.write_control_string(std::string("PUSH_REQUEST")); + proto_context.flush(true); set_housekeeping_timer(); { @@ -1134,7 +1127,7 @@ class Session : ProtoContext, // react to any tls warning triggered during the tls-handshake virtual void check_tls_warnings() { - uint32_t tls_warnings = get_tls_warnings(); + uint32_t tls_warnings = proto_context.get_tls_warnings(); if (tls_warnings & SSLAPI::TLS_WARN_SIG_MD5) { @@ -1151,13 +1144,13 @@ class Session : ProtoContext, void check_proto_warnings() { - if (uses_bs64_cipher()) + if (proto_context.uses_bs64_cipher()) { ClientEvent::Base::Ptr ev = new ClientEvent::Warn("Proto: Using a 64-bit block cipher that is vulnerable to the SWEET32 attack. Please inform your admin to upgrade to a stronger algorithm. Support for 64-bit block cipher will be dropped in the future."); cli_events->add_event(std::move(ev)); } - CompressContext::Type comp_type = Base::conf().comp_ctx.type(); + CompressContext::Type comp_type = proto_context.conf().comp_ctx.type(); // abort connection if compression is pushed and its support is unannounced if (comp_type != CompressContext::COMP_STUBv2 @@ -1203,15 +1196,15 @@ class Session : ProtoContext, if (!e && !halt) { // update current time - Base::update_now(); + proto_context.update_now(); housekeeping_schedule.reset(); - Base::housekeeping(); - if (Base::invalidated()) + proto_context.housekeeping(); + if (proto_context.invalidated()) { if (notify_callback) { - OPENVPN_LOG("Session invalidated: " << Error::name(Base::invalidation_reason())); + OPENVPN_LOG("Session invalidated: " << Error::name(proto_context.invalidation_reason())); stop(true); } else @@ -1231,12 +1224,12 @@ class Session : ProtoContext, if (halt) return; - Time next = Base::next_housekeeping(); + Time next = proto_context.next_housekeeping(); if (!housekeeping_schedule.similar(next)) { if (!next.is_infinite()) { - next.max(now()); + next.max(proto_context.now()); housekeeping_schedule.reset(next); housekeeping_timer.expires_at(next); housekeeping_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error) @@ -1377,7 +1370,7 @@ class Session : ProtoContext, void schedule_info_hold_callback() { - Base::update_now(); + proto_context.update_now(); info_hold_timer.expires_after(Time::Duration::seconds(1)); info_hold_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error) { @@ -1391,7 +1384,7 @@ class Session : ProtoContext, { if (!e && !halt) { - Base::update_now(); + proto_context.update_now(); if (info_hold) { for (auto &ev : *info_hold) @@ -1420,6 +1413,8 @@ class Session : ProtoContext, } #endif + ProtoContext proto_context; + openvpn_io::io_context &io_context; TransportClientFactory::Ptr transport_factory; diff --git a/openvpn/server/servproto.hpp b/openvpn/server/servproto.hpp index 857cfbb1..9447db49 100644 --- a/openvpn/server/servproto.hpp +++ b/openvpn/server/servproto.hpp @@ -109,16 +109,13 @@ class ServerProto }; // This is the main server-side client instance object - class Session : ProtoContext, // OpenVPN protocol implementation - public TransportLink, // Transport layer - public TunLink, // Tun/routing layer - public ManLink // Management layer + class Session : ProtoContextCallbackInterface, // Callback interface from protocol implementation + public TransportLink, // Transport layer + public TunLink, // Tun/routing layer + public ManLink // Management layer { friend class Factory; // calls constructor - using ProtoContext::now; - using ProtoContext::stat; - public: typedef RCPtr Ptr; @@ -142,11 +139,11 @@ class ServerProto peer_addr = addr; // init OpenVPN protocol handshake - ProtoContext::update_now(); - ProtoContext::reset(cookie_psid); - ProtoContext::set_local_peer_id(local_peer_id); - ProtoContext::start(cookie_psid); - ProtoContext::flush(true); + proto_context.update_now(); + proto_context.reset(cookie_psid); + proto_context.set_local_peer_id(local_peer_id); + proto_context.start(cookie_psid); + proto_context.flush(true); // coarse wakeup range housekeeping_schedule.init(Time::Duration::binary_ms(512), Time::Duration::binary_ms(1024)); @@ -182,8 +179,8 @@ class ServerProto ManLink::send->stats_notify(TransportLink::send->stats_poll(), true); } - ProtoContext::pre_destroy(); - ProtoContext::reset_dc_factory(); + proto_context.pre_destroy(); + proto_context.reset_dc_factory(); if (TransportLink::send) { TransportLink::send->stop(); @@ -206,23 +203,23 @@ class ServerProto virtual bool transport_recv(BufferAllocated &buf) override { bool ret = false; - if (!ProtoContext::primary_defined()) + if (!proto_context.primary_defined()) return false; try { - OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport RECV[" << buf.size() << "] " << client_endpoint_render() << ' ' << ProtoContext::dump_packet(buf)); + OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport RECV[" << buf.size() << "] " << client_endpoint_render() << ' ' << proto_context.dump_packet(buf)); // update current time - ProtoContext::update_now(); + proto_context.update_now(); // get packet type - ProtoContext::PacketType pt = ProtoContext::packet_type(buf); + ProtoContext::PacketType pt = proto_context.packet_type(buf); // process packet if (pt.is_data()) { // data packet - ret = ProtoContext::data_decrypt(pt, buf); + ret = proto_context.data_decrypt(pt, buf); if (buf.size()) { #ifdef OPENVPN_PACKET_LOG @@ -237,15 +234,15 @@ class ServerProto } // do a lightweight flush - ProtoContext::flush(false); + proto_context.flush(false); } else if (pt.is_control()) { // control packet - ret = ProtoContext::control_net_recv(pt, std::move(buf)); + ret = proto_context.control_net_recv(pt, std::move(buf)); // do a full flush - ProtoContext::flush(true); + proto_context.flush(true); } // schedule housekeeping wakeup @@ -269,7 +266,7 @@ class ServerProto // Return true if keepalive parameter(s) are enabled. virtual bool is_keepalive_enabled() const override { - return ProtoContext::is_keepalive_enabled(); + return proto_context.is_keepalive_enabled(); } // Disable keepalive for rest of session, but fetch @@ -278,7 +275,7 @@ class ServerProto virtual void disable_keepalive(unsigned int &keepalive_ping, unsigned int &keepalive_timeout) override { - ProtoContext::disable_keepalive(keepalive_ping, keepalive_timeout); + proto_context.disable_keepalive(keepalive_ping, keepalive_timeout); if (ManLink::send) ManLink::send->keepalive_override(keepalive_ping, keepalive_timeout); } @@ -286,7 +283,7 @@ class ServerProto // override the data channel factory virtual void override_dc_factory(const CryptoDCFactory::Ptr &dc_factory) override { - ProtoContext::dc_settings().set_factory(dc_factory); + proto_context.dc_settings().set_factory(dc_factory); } virtual ~Session() @@ -301,7 +298,7 @@ class ServerProto const Factory &factory, ManClientInstance::Factory::Ptr man_factory_arg, TunClientInstance::Factory::Ptr tun_factory_arg) - : ProtoContext(factory.clone_proto_config(), factory.stats), + : proto_context(this, factory.clone_proto_config(), factory.stats), housekeeping_timer(io_context_arg), disconnect_at(Time::infinite()), stats(factory.stats), @@ -318,11 +315,11 @@ class ServerProto // proto base class calls here for control channel network sends virtual void control_net_send(const Buffer &net_buf) override { - OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport SEND[" << net_buf.size() << "] " << client_endpoint_render() << ' ' << ProtoContext::dump_packet(net_buf)); + OPENVPN_LOG_SERVPROTO(instance_name() << " : Transport SEND[" << net_buf.size() << "] " << client_endpoint_render() << ' ' << proto_context.dump_packet(net_buf)); if (TransportLink::send) { if (TransportLink::send->transport_send_const(net_buf)) - ProtoContext::update_last_sent(); + proto_context.update_last_sent(); } } @@ -353,7 +350,7 @@ class ServerProto if (msg == "PUSH_REQUEST") { if (get_management()) - ManLink::send->push_request(ProtoContext::conf_ptr()); + ManLink::send->push_request(proto_context.conf_ptr()); else auth_failed("no management provider", ""); } @@ -368,6 +365,14 @@ class ServerProto } } + void active(bool primary) override + { + /* Currently the server does not do anything special when the connection + * is ready (control channel fully established). We probably should trigger + * sending a PUSH_REPLY here, when the client requested it via + * IV_PROTO_REQUEST_PUSH instead waiting for an explicit PUSH_REQUEST */ + } + virtual void auth_failed(const std::string &reason, const std::string &client_reason) override { @@ -379,7 +384,7 @@ class ServerProto if (halt || disconnect_type == DT_HALT_RESTART) return; - ProtoContext::update_now(); + proto_context.update_now(); if (TunLink::send && (disconnect_type < DT_RELAY_TRANSITION)) { @@ -388,13 +393,13 @@ class ServerProto disconnect_in(Time::Duration::seconds(10)); // not a real disconnect, just complete transition to relay } - if (ProtoContext::primary_defined()) + if (proto_context.primary_defined()) { BufferPtr buf(new BufferAllocated(64, 0)); buf_append_string(*buf, "RELAY"); buf->null_terminate(); - ProtoContext::control_send(std::move(buf)); - ProtoContext::flush(true); + proto_context.control_send(std::move(buf)); + proto_context.flush(true); } set_housekeeping_timer(); @@ -402,7 +407,7 @@ class ServerProto virtual void push_reply(std::vector &&push_msgs) override { - if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !ProtoContext::primary_defined()) + if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !proto_context.primary_defined()) return; if (disconnect_type == DT_AUTH_PENDING) @@ -411,17 +416,17 @@ class ServerProto cancel_disconnect(); } - ProtoContext::update_now(); + proto_context.update_now(); if (get_tun()) { - ProtoContext::init_data_channel(); + proto_context.init_data_channel(); for (auto &msg : push_msgs) { msg->null_terminate(); - ProtoContext::control_send(std::move(msg)); + proto_context.control_send(std::move(msg)); } - ProtoContext::flush(true); + proto_context.flush(true); set_housekeeping_timer(); } else @@ -445,7 +450,7 @@ class ServerProto if (halt || disconnect_type == DT_HALT_RESTART) return; - ProtoContext::update_now(); + proto_context.update_now(); BufferPtr buf(new BufferAllocated(128, BufferAllocated::GROW)); BufferStreamOut os(*buf); @@ -520,11 +525,11 @@ class ServerProto OPENVPN_LOG(instance_name() << " : Disconnect: " << ts << ' ' << reason); - if (ProtoContext::primary_defined()) + if (proto_context.primary_defined()) { buf->null_terminate(); - ProtoContext::control_send(std::move(buf)); - ProtoContext::flush(true); + proto_context.control_send(std::move(buf)); + proto_context.flush(true); } set_housekeeping_timer(); @@ -534,7 +539,7 @@ class ServerProto { if (halt || disconnect_type == DT_HALT_RESTART) return; - ProtoContext::update_now(); + proto_context.update_now(); disconnect_in(Time::Duration::seconds(seconds)); set_housekeeping_timer(); } @@ -543,7 +548,7 @@ class ServerProto { if (halt || (disconnect_type >= DT_RELAY_TRANSITION) || !seconds) return; - ProtoContext::update_now(); + proto_context.update_now(); disconnect_type = DT_AUTH_PENDING; disconnect_in(Time::Duration::seconds(seconds)); set_housekeeping_timer(); @@ -551,41 +556,41 @@ class ServerProto virtual void post_cc_msg(BufferPtr &&msg) override { - if (halt || !ProtoContext::primary_defined()) + if (halt || !proto_context.primary_defined()) return; - ProtoContext::update_now(); + proto_context.update_now(); msg->null_terminate(); - ProtoContext::control_send(std::move(msg)); - ProtoContext::flush(true); + proto_context.control_send(std::move(msg)); + proto_context.flush(true); set_housekeeping_timer(); } - virtual void stats_notify(const PeerStats &ps, const bool final) override + void stats_notify(const PeerStats &ps, const bool final) override { if (ManLink::send) ManLink::send->stats_notify(ps, final); } - virtual void float_notify(const PeerAddr::Ptr &addr) override + void float_notify(const PeerAddr::Ptr &addr) override { if (ManLink::send) ManLink::send->float_notify(addr); } - virtual void ipma_notify(const struct ovpn_tun_head_ipma &ipma) override + void ipma_notify(const struct ovpn_tun_head_ipma &ipma) override { if (ManLink::send) ManLink::send->ipma_notify(ipma); } - virtual void data_limit_notify(const int key_id, - const DataLimit::Mode cdl_mode, - const DataLimit::State cdl_status) override + void data_limit_notify(const int key_id, + const DataLimit::Mode cdl_mode, + const DataLimit::State cdl_status) override { - ProtoContext::update_now(); - ProtoContext::data_limit_notify(key_id, cdl_mode, cdl_status); - ProtoContext::flush(true); + proto_context.update_now(); + proto_context.data_limit_notify(key_id, cdl_mode, cdl_status); + proto_context.flush(true); set_housekeeping_timer(); } @@ -613,7 +618,7 @@ class ServerProto // and set_housekeeping_timer() called after this method void disconnect_in(const Time::Duration &dur) { - disconnect_at = now() + dur; + disconnect_at = proto_context.now() + dur; } void cancel_disconnect() @@ -628,13 +633,13 @@ class ServerProto if (!e && !halt) { // update current time - ProtoContext::update_now(); + proto_context.update_now(); housekeeping_schedule.reset(); - ProtoContext::housekeeping(); - if (ProtoContext::invalidated()) - invalidation_error(ProtoContext::invalidation_reason()); - else if (now() >= disconnect_at) + proto_context.housekeeping(); + if (proto_context.invalidated()) + invalidation_error(proto_context.invalidation_reason()); + else if (proto_context.now() >= disconnect_at) { switch (disconnect_type) { @@ -642,7 +647,7 @@ class ServerProto error("disconnect triggered"); break; case DT_RELAY_TRANSITION: - ProtoContext::pre_destroy(); + proto_context.pre_destroy(); break; case DT_AUTH_PENDING: auth_failed("Auth Pending Timeout", "Auth Pending Timeout"); @@ -664,13 +669,13 @@ class ServerProto void set_housekeeping_timer() { - Time next = ProtoContext::next_housekeeping(); + Time next = proto_context.next_housekeeping(); next.min(disconnect_at); if (!housekeeping_schedule.similar(next)) { if (!next.is_infinite()) { - next.max(now()); + next.max(proto_context.now()); housekeeping_schedule.reset(next); housekeeping_timer.expires_at(next); housekeeping_timer.async_wait([self = Ptr(this)](const openvpn_io::error_code &error) @@ -738,6 +743,8 @@ class ServerProto DT_RELAY_TRANSITION, DT_HALT_RESTART, }; + + ProtoContext proto_context; int disconnect_type = DT_NONE; bool preserve_session_id = true; diff --git a/openvpn/ssl/proto.hpp b/openvpn/ssl/proto.hpp index d9d2826b..4d8f89ea 100644 --- a/openvpn/ssl/proto.hpp +++ b/openvpn/ssl/proto.hpp @@ -168,6 +168,53 @@ enum } // namespace } // namespace proto_context_private +class ProtoContextCallbackInterface +{ + public: + /** + * Sends out bytes to the network. + */ + virtual void control_net_send(const Buffer &net_buf) = 0; + + /* + * Receive as packet from the network + * \note app may take ownership of app_bp via std::move + */ + virtual void control_recv(BufferPtr &&app_bp) = 0; + + /** Called on client to request username/password credentials. + * Should be overridden by derived class if credentials are required. + * username and password should be written into buf with write_auth_string(). + */ + virtual void client_auth(Buffer &buf) + { + write_empty_string(buf); // username + write_empty_string(buf); // password + } + + /** Called on server with credentials and peer info provided by client. + *Should be overriden by derived class if credentials are required. */ + virtual void server_auth(const std::string &username, + const SafeString &password, + const std::string &peer_info, + const AuthCert::Ptr &auth_cert) + { + } + + /** + * Writes an empty user or password string for the key-method 2 packet in the OpenVPN protocol + * @param buf buffer to write to + */ + static void write_empty_string(Buffer &buf) + { + uint8_t empty[]{0x00, 0x00}; // empty length field without content + buf.write(&empty, 2); + } + + //! Called when KeyContext transitions to ACTIVE state + virtual void active(bool primary) = 0; +}; + class ProtoContext { protected: @@ -1388,9 +1435,8 @@ class ProtoContext return out.str(); } - protected: - // used for reading/writing authentication strings (username, password, etc.) - + // used for reading/writing authentication strings (username, password, etc.) from buffer using the + // 2 byte prefix for length static void write_uint16_length(const size_t size, Buffer &buf) { if (size > 0xFFFF) @@ -1446,6 +1492,11 @@ class ProtoContext buf.null_terminate(); } + static void write_empty_string(Buffer &buf) + { + write_uint16_length(0, buf); + } + template static S read_control_string(const Buffer &buf) { @@ -1469,17 +1520,7 @@ class ProtoContext control_send(std::move(bp)); } - static unsigned char *skip_string(Buffer &buf) - { - const size_t len = read_uint16_length(buf); - return buf.read_alloc(len); - } - - static void write_empty_string(Buffer &buf) - { - write_uint16_length(0, buf); - } - + protected: // Packet structure for managing network packets, passed as a template // parameter to ProtoStackBase class Packet @@ -2785,7 +2826,7 @@ class ProtoContext const std::string username = read_auth_string(*buf); const SafeString password = read_auth_string(*buf); const std::string peer_info = read_auth_string(*buf); - proto.server_auth(username, password, peer_info, Base::auth_cert()); + proto.proto_callback->server_auth(username, password, peer_info, Base::auth_cert()); } } @@ -3641,9 +3682,11 @@ class ProtoContext OPENVPN_SIMPLE_EXCEPTION(select_key_context_error); - ProtoContext(const ProtoConfig::Ptr &config_arg, // configuration + ProtoContext(ProtoContextCallbackInterface *cb_arg, + const ProtoConfig::Ptr &config_arg, // configuration const SessionStats::Ptr &stats_arg) // error stats - : config(config_arg), + : proto_callback(cb_arg), + config(config_arg), stats(stats_arg), mode_(config_arg->ssl_factory->mode()), n_key_ids(0), @@ -4263,8 +4306,13 @@ class ProtoContext return *stats; } - protected: // debugging + bool is_state_client_wait_reset_ack() const + { + return primary_state() == C_WAIT_RESET_ACK; + } + + protected: int primary_state() const { if (primary) @@ -4291,32 +4339,11 @@ class ProtoContext secondary.reset(); } - virtual void control_net_send(const Buffer &net_buf) = 0; - - // app may take ownership of app_bp via std::move - virtual void control_recv(BufferPtr &&app_bp) = 0; - // Called on client to request username/password credentials. - // Should be overriden by derived class if credentials are required. - // username and password should be written into buf with write_auth_string(). - virtual void client_auth(Buffer &buf) - { - write_empty_string(buf); // username - write_empty_string(buf); // password - } - - // Called on server with credentials and peer info provided by client. - // Should be overriden by derived class if credentials are required. - virtual void server_auth(const std::string &username, - const SafeString &password, - const std::string &peer_info, - const AuthCert::Ptr &auth_cert) - { - } - - // Called when KeyContext transitions to ACTIVE state - virtual void active(bool primary) + // delegated to the callback/parent + void client_auth(Buffer &buf) { + proto_callback->client_auth(buf); } void update_last_received() @@ -4326,12 +4353,12 @@ class ProtoContext void net_send(const unsigned int key_id, const Packet &net_pkt) { - control_net_send(net_pkt.buffer()); + proto_callback->control_net_send(net_pkt.buffer()); } void app_recv(const unsigned int key_id, BufferPtr &&to_app_buf) { - control_recv(std::move(to_app_buf)); + proto_callback->control_recv(std::move(to_app_buf)); } // we're getting a request from peer to renegotiate. @@ -4471,7 +4498,7 @@ class ProtoContext case KeyContext::KEV_ACTIVE: OPENVPN_LOG_PROTO_VERBOSE(debug_prefix() << " SESSION_ACTIVE"); primary->rekey(CryptoDCInstance::ACTIVATE_PRIMARY); - active(true); + proto_callback->active(true); break; case KeyContext::KEV_RENEGOTIATE: case KeyContext::KEV_RENEGOTIATE_FORCE: @@ -4511,7 +4538,7 @@ class ProtoContext secondary->rekey(CryptoDCInstance::NEW_SECONDARY); if (primary) primary->prepare_expire(); - active(false); + proto_callback->active(false); break; case KeyContext::KEV_BECOME_PRIMARY: if (!secondary->invalidated()) @@ -4590,6 +4617,13 @@ class ProtoContext // BEGIN ProtoContext data members + /** the class that uses this class needs to be called back on a few things. Typically a class + * that uses this class as field for composition. This parent/callback class needs to ensure that + * it lives longer than this class, e.g. by having this class as field as this class blindly + * assumes that this pointer is always valid for its lifetime + */ + ProtoContextCallbackInterface *proto_callback; + ProtoConfig::Ptr config; SessionStats::Ptr stats; diff --git a/test/unittests/test_proto.cpp b/test/unittests/test_proto.cpp index 3264e6ec..19232bdb 100644 --- a/test/unittests/test_proto.cpp +++ b/test/unittests/test_proto.cpp @@ -353,24 +353,20 @@ class DroughtMeasure }; // test the OpenVPN protocol implementation in ProtoContext -class TestProto : public ProtoContext +class TestProto : public ProtoContextCallbackInterface { - typedef ProtoContext Base; + /* Callback methods that are not used */ + void active(bool primary) override + { + } - using Base::is_server; - using Base::mode; - using Base::now; public: - using Base::flush; - - typedef Base::PacketType PacketType; - OPENVPN_EXCEPTION(session_invalidated); - TestProto(const Base::ProtoConfig::Ptr &config, + TestProto(const ProtoContext::ProtoConfig::Ptr &config, const SessionStats::Ptr &stats) - : Base(config, stats), + : proto_context(this, config, stats), control_drought("control", config->now), data_drought("data", config->now), frame(config->frame) @@ -382,27 +378,26 @@ class TestProto : public ProtoContext void reset() { net_out.clear(); - Base::reset(); - - Base::conf().mss_parms.mssfix = MSSParms::MSSFIX_DEFAULT; + proto_context.reset(); + proto_context.conf().mss_parms.mssfix = MSSParms::MSSFIX_DEFAULT; } void initial_app_send(const char *msg) { - Base::start(); + proto_context.start(); const size_t msglen = std::strlen(msg) + 1; BufferAllocated app_buf((unsigned char *)msg, msglen, 0); copy_progress(app_buf); control_send(std::move(app_buf)); - flush(true); + proto_context.flush(true); } void app_send_templ_init(const char *msg) { - Base::start(); + proto_context.start(); const size_t msglen = std::strlen(msg) + 1; templ.reset(new BufferAllocated((unsigned char *)msg, msglen, 0)); - flush(true); + proto_context.flush(true); } void app_send_templ() @@ -421,9 +416,9 @@ class TestProto : public ProtoContext bool do_housekeeping() { - if (now() >= Base::next_housekeeping()) + if (proto_context.now() >= proto_context.next_housekeeping()) { - Base::housekeeping(); + proto_context.housekeeping(); return true; } else @@ -433,13 +428,13 @@ class TestProto : public ProtoContext void control_send(BufferPtr &&app_bp) { app_bytes_ += app_bp->size(); - Base::control_send(std::move(app_bp)); + proto_context.control_send(std::move(app_bp)); } void control_send(BufferAllocated &&app_buf) { app_bytes_ += app_buf.size(); - Base::control_send(std::move(app_buf)); + proto_context.control_send(std::move(app_buf)); } BufferPtr data_encrypt_string(const char *str) @@ -453,12 +448,12 @@ class TestProto : public ProtoContext void data_encrypt(BufferAllocated &in_out) { - Base::data_encrypt(in_out); + proto_context.data_encrypt(in_out); } - void data_decrypt(const PacketType &type, BufferAllocated &in_out) + void data_decrypt(const ProtoContext::PacketType &type, BufferAllocated &in_out) { - Base::data_decrypt(type, in_out); + proto_context.data_decrypt(type, in_out); if (in_out.size()) { data_bytes_ += in_out.size(); @@ -500,13 +495,8 @@ class TestProto : public ProtoContext void check_invalidated() { - if (Base::invalidated()) - throw session_invalidated(Error::name(Base::invalidation_reason())); - } - - bool is_state_client_wait_reset_ack() const - { - return primary_state() == C_WAIT_RESET_ACK; + if (proto_context.invalidated()) + throw session_invalidated(Error::name(proto_context.invalidation_reason())); } void disable_xmit() @@ -514,13 +504,15 @@ class TestProto : public ProtoContext disable_xmit_ = true; } + ProtoContext proto_context; + std::deque net_out; DroughtMeasure control_drought; DroughtMeasure data_drought; private: - virtual void control_net_send(const Buffer &net_buf) + void control_net_send(const Buffer &net_buf) override { if (disable_xmit_) return; @@ -528,7 +520,7 @@ class TestProto : public ProtoContext net_out.push_back(BufferPtr(new BufferAllocated(net_buf, 0))); } - virtual void control_recv(BufferPtr &&app_bp) + void control_recv(BufferPtr &&app_bp) override { BufferPtr work; work.swap(app_bp); @@ -559,7 +551,7 @@ class TestProto : public ProtoContext void modmsg(BufferPtr &buf) { char *msg = (char *)buf->data(); - if (is_server()) + if (proto_context.is_server()) { msg[8] = 'S'; msg[11] = 'C'; @@ -602,9 +594,9 @@ class TestProtoClient : public TestProto typedef TestProto Base; public: - TestProtoClient(const Base::ProtoConfig::Ptr &config, + TestProtoClient(const ProtoContext::ProtoConfig::Ptr &config, const SessionStats::Ptr &stats) - : Base(config, stats) + : TestProto(config, stats) { } @@ -613,21 +605,26 @@ class TestProtoClient : public TestProto { const std::string username("foo"); const std::string password("bar"); - Base::write_auth_string(username, buf); - Base::write_auth_string(password, buf); + ProtoContext::write_auth_string(username, buf); + ProtoContext::write_auth_string(password, buf); } }; class TestProtoServer : public TestProto { - typedef TestProto Base; public: + void start() + { + proto_context.start(); + } + OPENVPN_SIMPLE_EXCEPTION(auth_failed); - TestProtoServer(const Base::ProtoConfig::Ptr &config, + + TestProtoServer(const ProtoContext::ProtoConfig::Ptr &config, const SessionStats::Ptr &stats) - : Base(config, stats) + : TestProto(config, stats) { } @@ -687,7 +684,7 @@ class NoisyWire a.app_send_templ(); // queue a data channel packet - if (a.data_channel_ready()) + if (a.proto_context.data_channel_ready()) { BufferPtr bp = a.data_encrypt_string("Waiting for godot A... Waiting for godot B... Waiting for godot C... Waiting for godot D... Waiting for godot E... Waiting for godot F... Waiting for godot G... Waiting for godot H... Waiting for godot I... Waiting for godot J..."); wire.push_back(bp); @@ -710,14 +707,14 @@ class NoisyWire BufferPtr bp = recv(); if (!bp) break; - typename T2::PacketType pt = b.packet_type(*bp); + typename ProtoContext::PacketType pt = b.proto_context.packet_type(*bp); if (pt.is_control()) { #ifdef VERBOSE if (!b.control_net_validate(pt, *bp)) // not strictly necessary since control_net_recv will also validate std::cout << now->raw() << " " << title << " CONTROL PACKET VALIDATION FAILED" << std::endl; #endif - b.control_net_recv(pt, std::move(bp)); + b.proto_context.control_net_recv(pt, std::move(bp)); } else if (pt.is_data()) { @@ -744,11 +741,11 @@ class NoisyWire #ifdef VERBOSE std::cout << now->raw() << " " << title << " KEY_STATE_ERROR" << std::endl; #endif - b.stat().error(Error::KEY_STATE_ERROR); + b.proto_context.stat().error(Error::KEY_STATE_ERROR); } #ifdef SIMULATE_UDP_AMPLIFY_ATTACK - if (b.is_state_client_wait_reset_ack()) + if (b.proto_context.is_state_client_wait_reset_ack()) { b.disable_xmit(); #ifdef VERBOSE @@ -757,7 +754,7 @@ class NoisyWire } #endif } - b.flush(true); + b.proto_context.flush(true); } private: @@ -1148,8 +1145,8 @@ int test(const int thread_num) << " CTRL=" << cli_proto.n_control_recv() << '/' << cli_proto.n_control_send() << '/' << serv_proto.n_control_recv() << '/' << serv_proto.n_control_send() #endif << " D=" << cli_proto.control_drought().raw() << '/' << cli_proto.data_drought().raw() << '/' << serv_proto.control_drought().raw() << '/' << serv_proto.data_drought().raw() - << " N=" << cli_proto.negotiations() << '/' << serv_proto.negotiations() - << " SH=" << cli_proto.slowest_handshake().raw() << '/' << serv_proto.slowest_handshake().raw() + << " N=" << cli_proto.proto_context.negotiations() << '/' << serv_proto.proto_context.negotiations() + << " SH=" << cli_proto.proto_context.slowest_handshake().raw() << '/' << serv_proto.proto_context.slowest_handshake().raw() << " HE=" << cli_stats->get_error_count(Error::HANDSHAKE_TIMEOUT) << '/' << serv_stats->get_error_count(Error::HANDSHAKE_TIMEOUT) << std::endl;