diff --git a/openvpn/common/base64.hpp b/openvpn/common/base64.hpp index 36a63bb4..4bd15cda 100644 --- a/openvpn/common/base64.hpp +++ b/openvpn/common/base64.hpp @@ -56,6 +56,7 @@ namespace openvpn { OPENVPN_SIMPLE_EXCEPTION(base64_decode_error); // altmap is "+/=" by default + // another possible encoding for URLs: "-_." Base64(const char *altmap = nullptr) { // build encoding map @@ -72,6 +73,8 @@ namespace openvpn { } if (!altmap) altmap = "+/="; + if (std::strlen(altmap) != 3) + throw base64_bad_map(); enc[62] = altmap[0]; enc[63] = altmap[1]; equal = altmap[2]; @@ -162,6 +165,30 @@ namespace openvpn { } } + template + bool is_base64(const V& data, const size_t expected_decoded_length) const + { + const size_t size = data.size(); + if (size != encoded_len(expected_decoded_length)) + return false; + const size_t eq_begin = size - num_eq(expected_decoded_length); + for (size_t i = 0; i < size; ++i) + { + const char c = data[i]; + if (i < eq_begin) + { + if (!is_base64_char(c)) + return false; + } + else + { + if (c != equal) + return false; + } + } + return true; + } + private: bool is_base64_char(const char c) const { @@ -202,6 +229,16 @@ namespace openvpn { return val; } + static size_t encoded_len(const size_t decoded_len) + { + return (decoded_len * 4 / 3 + 3) & ~3; + } + + static size_t num_eq(const size_t decoded_len) + { + return (-1 - decoded_len) % 3; + } + unsigned char enc[64]; unsigned char dec[128]; unsigned char equal; @@ -210,11 +247,14 @@ namespace openvpn { // provide a static Base64 object OPENVPN_EXTERN const Base64* base64; // GLOBAL + OPENVPN_EXTERN const Base64* base64_urlsafe; // GLOBAL inline void base64_init_static() { if (!base64) base64 = new Base64(); + if (!base64_urlsafe) + base64_urlsafe = new Base64("-_."); } inline void base64_uninit_static() @@ -224,6 +264,11 @@ namespace openvpn { delete base64; base64 = nullptr; } + if (base64_urlsafe) + { + delete base64_urlsafe; + base64_urlsafe = nullptr; + } } }