0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-20 03:52:17 +02:00

Use static dispatch for proto unknown field detection

Add a derive macro and attach it to each generated protobuf message. The 
generated code will walk each field in the message and dispatch recursively to 
the same trait to find all unknown fields. Keep the existing 
dynamically-dispatched descriptor-walking implementation since it's easier to 
understand, but only use it to ensure parity with the macro-generated version 
via test cases.
This commit is contained in:
Alex Konradi 2024-01-23 09:05:29 -05:00 committed by GitHub
parent 1158caf302
commit 9dc14c6726
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 742 additions and 69 deletions

13
Cargo.lock generated
View File

@ -1806,6 +1806,7 @@ dependencies = [
"hex-literal",
"hkdf",
"hmac",
"libsignal-message-backup-macros",
"libsignal-protocol",
"log",
"protobuf",
@ -1821,6 +1822,18 @@ dependencies = [
"zkgroup",
]
[[package]]
name = "libsignal-message-backup-macros"
version = "0.1.0"
dependencies = [
"heck 0.3.3",
"lazy_static",
"proc-macro2",
"quote",
"syn 2.0.48",
"test-case",
]
[[package]]
name = "libsignal-net"
version = "0.1.0"

View File

@ -46,7 +46,7 @@
<h2>Overview of licenses:</h2>
<ul class="licenses-overview">
<li><a href="#MIT">MIT License</a> (281)</li>
<li><a href="#AGPL-3.0">GNU Affero General Public License v3.0</a> (23)</li>
<li><a href="#AGPL-3.0">GNU Affero General Public License v3.0</a> (24)</li>
<li><a href="#Apache-2.0">Apache License 2.0</a> (11)</li>
<li><a href="#BSD-3-Clause">BSD 3-Clause &quot;New&quot; or &quot;Revised&quot; License</a> (8)</li>
<li><a href="#ISC">ISC License</a> (3)</li>
@ -740,6 +740,7 @@ For more information on this, and how to apply and follow the GNU AGPL, see
<li><a href="https://crates.io/crates/libsignal-ffi">libsignal-ffi 0.39.2</a></li>
<li><a href="https://crates.io/crates/libsignal-jni">libsignal-jni 0.39.2</a></li>
<li><a href="https://crates.io/crates/libsignal-message-backup">libsignal-message-backup 0.1.0</a></li>
<li><a href="https://crates.io/crates/libsignal-message-backup-macros">libsignal-message-backup-macros 0.1.0</a></li>
<li><a href="https://crates.io/crates/libsignal-net">libsignal-net 0.1.0</a></li>
<li><a href="https://crates.io/crates/libsignal-node">libsignal-node 0.39.2</a></li>
<li><a href="https://crates.io/crates/libsignal-protocol">libsignal-protocol 0.1.0</a></li>

View File

@ -669,7 +669,7 @@ For more information on this, and how to apply and follow the GNU AGPL, see
```
## attest 0.1.0, device-transfer 0.1.0, libsignal-bridge 0.1.0, libsignal-bridge-macros 0.1.0, libsignal-core 0.1.0, libsignal-ffi 0.39.2, libsignal-jni 0.39.2, libsignal-message-backup 0.1.0, libsignal-net 0.1.0, libsignal-node 0.39.2, libsignal-protocol 0.1.0, libsignal-svr3 0.1.0, poksho 0.7.0, signal-crypto 0.1.0, signal-media 0.1.0, signal-neon-futures 0.1.0, signal-neon-futures-tests 0.1.0, signal-pin 0.1.0, usernames 0.1.0, zkcredential 0.1.0, zkgroup 0.9.0
## attest 0.1.0, device-transfer 0.1.0, libsignal-bridge 0.1.0, libsignal-bridge-macros 0.1.0, libsignal-core 0.1.0, libsignal-ffi 0.39.2, libsignal-jni 0.39.2, libsignal-message-backup 0.1.0, libsignal-message-backup-macros 0.1.0, libsignal-net 0.1.0, libsignal-node 0.39.2, libsignal-protocol 0.1.0, libsignal-svr3 0.1.0, poksho 0.7.0, signal-crypto 0.1.0, signal-media 0.1.0, signal-neon-futures 0.1.0, signal-neon-futures-tests 0.1.0, signal-pin 0.1.0, usernames 0.1.0, zkcredential 0.1.0, zkgroup 0.9.0
```
GNU AFFERO GENERAL PUBLIC LICENSE

View File

@ -924,7 +924,7 @@ You should also get your employer (if you work as a programmer) or school, if an
<key>License</key>
<string>GNU Affero General Public License v3.0</string>
<key>Title</key>
<string>attest 0.1.0, device-transfer 0.1.0, libsignal-bridge 0.1.0, libsignal-bridge-macros 0.1.0, libsignal-core 0.1.0, libsignal-ffi 0.39.2, libsignal-jni 0.39.2, libsignal-message-backup 0.1.0, libsignal-net 0.1.0, libsignal-node 0.39.2, libsignal-protocol 0.1.0, libsignal-svr3 0.1.0, poksho 0.7.0, signal-crypto 0.1.0, signal-media 0.1.0, signal-neon-futures 0.1.0, signal-neon-futures-tests 0.1.0, signal-pin 0.1.0, usernames 0.1.0, zkcredential 0.1.0, zkgroup 0.9.0</string>
<string>attest 0.1.0, device-transfer 0.1.0, libsignal-bridge 0.1.0, libsignal-bridge-macros 0.1.0, libsignal-core 0.1.0, libsignal-ffi 0.39.2, libsignal-jni 0.39.2, libsignal-message-backup 0.1.0, libsignal-message-backup-macros 0.1.0, libsignal-net 0.1.0, libsignal-node 0.39.2, libsignal-protocol 0.1.0, libsignal-svr3 0.1.0, poksho 0.7.0, signal-crypto 0.1.0, signal-media 0.1.0, signal-neon-futures 0.1.0, signal-neon-futures-tests 0.1.0, signal-pin 0.1.0, usernames 0.1.0, zkcredential 0.1.0, zkgroup 0.9.0</string>
<key>Type</key>
<string>PSGroupSpecifier</string>
</dict>

View File

@ -10,6 +10,7 @@ authors = ["Signal Messenger LLC"]
license = "AGPL-3.0-only"
[dependencies]
libsignal-message-backup-macros = { path = "macros" }
libsignal-protocol = { path = "../protocol" }
signal-crypto = { path = "../crypto" }
usernames = { path = "../usernames" }

View File

@ -3,30 +3,75 @@
// SPDX-License-Identifier: AGPL-3.0-only
//
use std::io::Write as _;
use protobuf_codegen::{Customize, CustomizeCallback};
const DERIVE_LINE: &str = "#[derive(crate::unknown::visit_static::VisitUnknownFields)]";
/// Code-generation hook that annotates the generated protobuf Rust types so
/// the `VisitUnknownFields` derive macro can be applied to them.
struct DeriveVisitUnknownFields;
impl CustomizeCallback for DeriveVisitUnknownFields {
// Tags every generated field with `#[field_name("<proto name>")]`; the `{:?}`
// formatting of the name yields a quoted, escaped string literal.
fn field(&self, field: &protobuf::reflect::FieldDescriptor) -> Customize {
Customize::default().before(&format!("#[field_name({:?})]", field.name()))
}
// Messages, enums and oneofs all get the same derive line placed before
// their generated definition.
fn message(&self, _: &protobuf::reflect::MessageDescriptor) -> Customize {
Customize::default().before(DERIVE_LINE)
}
fn enumeration(&self, _: &protobuf::reflect::EnumDescriptor) -> Customize {
Customize::default().before(DERIVE_LINE)
}
fn oneof(&self, _: &protobuf::reflect::OneofDescriptor) -> Customize {
Customize::default().before(DERIVE_LINE)
}
}
fn main() {
const PROTOS: &[&str] = &["src/proto/backup.proto", "src/proto/test.proto"];
const PROTOS_DIR: &str = "protos";
protobuf_codegen::Codegen::new()
.protoc()
.protoc_extra_arg(
// Enable optional fields. This isn't needed in the most recent
// protobuf compiler version, but adding it lets us support older
// versions that might be installed in CI or on developer machines.
"--experimental_allow_proto3_optional",
)
.include("src")
.inputs(PROTOS)
.cargo_out_dir(PROTOS_DIR)
let out_dir = format!(
"{}/{PROTOS_DIR}",
std::env::var("OUT_DIR").expect("OUT_DIR env var not set")
);
std::fs::create_dir_all(&out_dir).expect("failed to create output directory");
let make_codegen = || {
let mut codegen = protobuf_codegen::Codegen::new();
codegen
.protoc()
.protoc_extra_arg(
// Enable optional fields. This isn't needed in the most recent
// protobuf compiler version, but adding it lets us support older
// versions that might be installed in CI or on developer machines.
"--experimental_allow_proto3_optional",
)
.customize_callback(DeriveVisitUnknownFields)
// Use the lite runtime to reduce code size.
.customize(Customize::default().lite_runtime(true))
.include("src")
.out_dir(&out_dir);
codegen
};
// For the test-only protos, use the full runtime instead of the lite
// runtime. This lets us test the dynamic and static unknown field dispatch.
const TEST_PROTOS: &[&str] = &["src/proto/test.proto"];
make_codegen()
.inputs(TEST_PROTOS)
.customize(Customize::default().lite_runtime(false))
.run_from_script();
// Mark the test.proto module as test-only.
let out_mod_rs = format!("{}/{PROTOS_DIR}/mod.rs", std::env::var("OUT_DIR").unwrap());
let mut contents = std::fs::read_to_string(&out_mod_rs).unwrap();
let insert_pos = contents.find("pub mod test;").unwrap_or(0);
const PROTOS: &[&str] = &["src/proto/backup.proto"];
make_codegen().inputs(PROTOS).run_from_script();
contents.insert_str(insert_pos, "\n#[cfg(test)] // only for testing\n");
std::fs::write(out_mod_rs, contents).unwrap();
// Add the test.proto module to mod.rs as test-only.
let out_mod_rs = format!("{out_dir}/mod.rs");
std::fs::OpenOptions::new()
.append(true)
.open(&out_mod_rs)
.unwrap_or_else(|e| panic!("expected {out_mod_rs} to be writable, got {e}"))
.write_all(b" #[cfg(test)] pub mod test; ")
.expect("failed to write");
for proto in PROTOS {
println!("cargo:rerun-if-changed={}", proto);

View File

@ -0,0 +1,23 @@
#
# Copyright (C) 2024 Signal Messenger, LLC.
# SPDX-License-Identifier: AGPL-3.0-only
#
[package]
name = "libsignal-message-backup-macros"
version = "0.1.0"
edition = "2021"
authors = ["Signal Messenger LLC"]
license = "AGPL-3.0-only"
[lib]
# `proc-macro` is the canonical key name; the underscore spelling
# `proc_macro` is a deprecated alias that newer Cargo versions warn about.
proc-macro = true
[dependencies]
heck = "0.3.1"
lazy_static = "1.4.0"
proc-macro2 = "1.0.74"
quote = "1.0.35"
syn = { version = "2.0.46", features = ["full", "extra-traits"] }
[dev-dependencies]
test-case = "3.3.1"

View File

@ -0,0 +1,354 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
use std::borrow::Cow;
use heck::SnakeCase;
use proc_macro::TokenStream;
use proc_macro2::{Delimiter, Group, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use syn::spanned::Spanned;
use syn::{
self, parse2, parse_macro_input, Attribute, DeriveInput, Field, Ident, LitStr, MetaList,
TypePath,
};
// Declares a unit struct `$name` whose `ToTokens` impl emits `$path`. This
// lets `quote!` bodies interpolate `#Name` instead of repeating a long path.
macro_rules! tokens_alias {
($name:ident, $path:path) => {
struct $name;
impl ToTokens for $name {
fn to_tokens(&self, tokens: &mut TokenStream2) {
quote!($path).to_tokens(tokens)
}
}
};
}
// Paths of the traits and types that the generated impls refer to. They are
// all rooted at `crate::`, so the derive only resolves inside a crate that
// has a matching `unknown` module.
tokens_alias!(
VisitUnknownFields,
crate::unknown::visit_static::VisitUnknownFields
);
tokens_alias!(
VisitContainerUnknownFields,
crate::unknown::visit_static::VisitContainerUnknownFields
);
// NOTE(review): `UnknownFieldVisitor` does not appear in any `quote!` in the
// code visible here - confirm whether it is still needed.
tokens_alias!(UnknownFieldVisitor, crate::unknown::UnknownFieldVisitor);
tokens_alias!(Visitor, crate::unknown::visit_static::Visitor);
tokens_alias!(PathType, crate::unknown::Path<'_>);
tokens_alias!(Path, crate::unknown::Path);
tokens_alias!(Part, crate::unknown::Part);
// Names of the two non-`self` parameters of the generated method.
tokens_alias!(VisitorArgName, visitor);
tokens_alias!(PathArgName, path);
#[proc_macro_derive(VisitUnknownFields, attributes(field_name))]
pub fn derive_visit_unknown_fields(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item);
derive_visit_unknown_fields_impl(input).into()
}
fn derive_visit_unknown_fields_impl(input: DeriveInput) -> TokenStream2 {
if input.generics.lifetimes().next().is_some()
|| input.generics.type_params().next().is_some()
|| input.generics.const_params().next().is_some()
{
return syn::Error::new_spanned(input.generics, "generics are not supported")
.into_compile_error();
}
match input.data {
syn::Data::Union(u) => {
syn::Error::new_spanned(u.union_token, "unions are not supported").into_compile_error()
}
syn::Data::Enum(e) => derive_has_unknown_fields_enum_impl(input.ident, e),
syn::Data::Struct(e) => derive_has_unknown_fields_struct_impl(input.ident, e),
}
}
fn derive_has_unknown_fields_struct_impl(ident: Ident, e: syn::DataStruct) -> TokenStream2 {
let visit_fields: Vec<_> = e.fields.into_iter().map(VisitField::from).collect();
let field_idents = visit_fields.iter().map(|f| &f.ident);
let destruct = Group::new(Delimiter::Brace, quote!(#(#field_idents),*));
quote! {
impl #VisitUnknownFields for #ident {
fn visit_unknown_fields(&self, #PathArgName: #PathType, #VisitorArgName: &mut impl #Visitor) -> std::ops::ControlFlow<()>{
let Self #destruct = self;
#({ #visit_fields };)*
std::ops::ControlFlow::Continue(())
}
}
}
}
fn derive_has_unknown_fields_enum_impl(ident: Ident, e: syn::DataEnum) -> TokenStream2 {
let arms = e.variants.into_iter().map(|variant| {
let ident = &variant.ident;
let delimiter = match &variant.fields {
syn::Fields::Unnamed(_) => Delimiter::Parenthesis,
syn::Fields::Unit => Delimiter::None,
syn::Fields::Named(_) => {
unreachable!("generated protobuf code doesn't have enum variants with named fields")
}
};
// This is either an enum or oneof in a protobuf. If there is exactly
// one field, generate a name for the inner field from the variant name.
let mut fields = variant.fields;
{
let mut it = fields.iter_mut();
if let (Some(first), None) = (it.next(), it.next()) {
let field_name = {
let candidate = variant.ident.to_string().to_snake_case();
match candidate.as_str() {
"self" => "self_".to_string(),
_ => candidate,
}
};
first.ident = Some(Ident::new(&field_name, first.span()))
}
}
let visit_fields: Vec<_> = fields.into_iter().map(VisitField::from).collect();
let field_names = visit_fields.iter().map(|f| &f.ident);
let fields = Group::new(delimiter, quote!(#(#field_names),*));
quote! {
Self::#ident #fields => { #(#visit_fields);* }
}
});
quote! {
impl #VisitUnknownFields for #ident {
fn visit_unknown_fields(&self, path: #PathType, #VisitorArgName: &mut impl #Visitor) -> std::ops::ControlFlow<()> {
match self {
#(#arms,)*
};
std::ops::ControlFlow::Continue(())
}
}
}
}
/// Produces a token stream for visiting a single field.
struct VisitField {
// Name the field is bound to in the generated destructuring pattern.
ident: Ident,
// Name taken from a `#[field_name("...")]` attribute, if one was present;
// otherwise the stringified `ident` is used in the emitted path.
proto_field_name: Option<String>,
// Selects which visiting trait the emitted code calls.
field_type: FieldType,
}
impl ToTokens for VisitField {
    /// Emits the expression that visits this field.
    ///
    /// A `Single` field gets a child path built from the enclosing path and is
    /// visited through `VisitUnknownFields`. A `Container` field is handed the
    /// enclosing path and the field name, and builds the element paths itself
    /// through `VisitContainerUnknownFields`. Both expressions end in `?`.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        let binding = &self.ident;
        // Prefer the name from the .proto file; fall back to the Rust binding.
        let field_name: Cow<'_, str> = match &self.proto_field_name {
            Some(name) => Cow::Borrowed(name.as_str()),
            None => Cow::Owned(binding.to_string()),
        };

        let expansion = match self.field_type {
            FieldType::Single => quote! {
                let #PathArgName = #Path::Branch {
                    parent: &#PathArgName,
                    field_name: #field_name,
                    part: #Part::Field,
                };
                #VisitUnknownFields::visit_unknown_fields(#binding, #PathArgName, #VisitorArgName)?
            },
            FieldType::Container => quote! {
                #VisitContainerUnknownFields::visit_unknown_fields_within(
                    #binding,
                    #PathArgName,
                    #field_name,
                    #VisitorArgName
                )?
            },
        };
        tokens.extend(expansion);
    }
}
impl From<Field> for VisitField {
fn from(field: Field) -> Self {
let proto_field_name = field
.attrs
.into_iter()
.find_map(FieldNameAttr::new)
.map(|f| f.field_name);
let ident = field.ident.expect("tuple structs aren't supported");
let field_type = field.ty.into();
Self {
ident,
proto_field_name,
field_type,
}
}
}
/// Parsed attribute that indicates the name of a field in the source .proto
/// file.
///
/// The attribute looks like `#[field_name("protoName")]` where the string
/// literal `protoName` is the original name of the field in the .proto file.
struct FieldNameAttr {
// Value of the string literal, without the surrounding quotes.
field_name: String,
}
impl FieldNameAttr {
    /// Identifier of the helper attribute recognized by the derive.
    const ATTR_LABEL: &'static str = "field_name";

    /// Tries to interpret `attr` as `#[field_name("...")]`.
    ///
    /// Returns `None` for any attribute that is not a list-style attribute
    /// named `field_name`.
    ///
    /// # Panics
    ///
    /// Panics if the attribute matches but its argument is not a single
    /// string literal.
    fn new(attr: Attribute) -> Option<Self> {
        if let syn::Meta::List(MetaList { path, tokens, .. }) = attr.meta {
            if path.is_ident(Self::ATTR_LABEL) {
                let literal: LitStr = parse2(tokens).expect("not a string literal");
                return Some(Self {
                    field_name: literal.value(),
                });
            }
        }
        None
    }
}
/// Whether a field is a scalar or a container.
enum FieldType {
/// Any type not recognized as a container; visited directly through
/// `VisitUnknownFields`.
Single,
/// Some kind of wrapped field, or a `protobuf::SpecialFields`.
Container,
}
impl From<syn::Type> for FieldType {
    /// Classifies a field by the textual form of its type path.
    ///
    /// A plain path type (no qualified-self) whose rendered tokens start with
    /// one of the fully qualified container paths is a `Container`; every
    /// other type is `Single`.
    fn from(value: syn::Type) -> Self {
        let rendered = match value {
            syn::Type::Path(TypePath { path, qself: None }) => path.to_token_stream().to_string(),
            _ => return FieldType::Single,
        };

        // Both sides are rendered through `TokenStream::to_string`, so the
        // prefix comparison sees the same token spacing on each.
        let container_prefixes = [
            quote!(::std::vec::Vec),
            quote!(::std::collections::HashMap),
            quote!(::std::option::Option),
            quote!(::protobuf::SpecialFields),
        ];
        let is_container = container_prefixes
            .iter()
            .any(|prefix| rendered.starts_with(&prefix.to_string()));

        if is_container {
            FieldType::Container
        } else {
            FieldType::Single
        }
    }
}
// Expansion tests: each case pairs a type definition with the exact impl the
// derive is expected to generate for it.
#[cfg(test)]
mod test {
use quote::ToTokens;
use syn::parse_quote;
use test_case::test_case;
use super::*;
// A struct with two scalar fields and no `#[field_name]` attributes, so the
// emitted field names fall back to the Rust field identifiers.
fn message() -> (syn::DeriveInput, syn::ItemImpl) {
(
parse_quote! {
struct Foo {
pub pub_field: bool,
priv_field: String,
}
},
parse_quote! {
impl crate::unknown::visit_static::VisitUnknownFields for Foo {
fn visit_unknown_fields(
&self,
path: crate::unknown::Path<'_>,
visitor: &mut impl crate::unknown::visit_static::Visitor) -> std::ops::ControlFlow<()>
{
let Self {
pub_field, priv_field
} = self;
{
let path = crate::unknown::Path::Branch {
parent: & path,
field_name: "pub_field",
part: crate::unknown::Part::Field,
};
crate::unknown::visit_static::VisitUnknownFields::visit_unknown_fields(pub_field, path, visitor)?
};
{
let path = crate::unknown::Path::Branch {
parent: & path,
field_name: "priv_field",
part: crate::unknown::Part::Field,
};
crate::unknown::visit_static::VisitUnknownFields::visit_unknown_fields(priv_field, path, visitor)?
};
std::ops::ControlFlow::Continue(())
}
}
},
)
}
// An enum of single-field tuple variants; the binding and field name of each
// arm are the snake-cased variant name.
fn oneof() -> (syn::DeriveInput, syn::ItemImpl) {
(
parse_quote! {
enum Foo {
AField(AField),
BField(BField),
}
},
parse_quote! {
impl crate::unknown::visit_static::VisitUnknownFields for Foo {
fn visit_unknown_fields(
&self,
path: crate::unknown::Path<'_>,
visitor: &mut impl crate::unknown::visit_static::Visitor) -> std::ops::ControlFlow<()>
{
match self {
Self::AField(a_field) => {
let path = crate::unknown::Path::Branch {
parent: & path,
field_name: "a_field",
part: crate::unknown::Part::Field,
};
crate::unknown::visit_static::VisitUnknownFields::visit_unknown_fields(a_field, path, visitor)?
},
Self::BField(b_field) => {
let path = crate::unknown::Path::Branch {
parent: & path,
field_name: "b_field",
part: crate::unknown::Part::Field,
};
crate::unknown::visit_static::VisitUnknownFields::visit_unknown_fields(b_field, path, visitor)?
},
};
std::ops::ControlFlow::Continue(())
}
}
},
)
}
// Runs the derive on the input and compares the parsed result with the
// expected impl; both are printed as token streams when they differ.
#[test_case(message)]
#[test_case(oneof)]
fn has_unknown_fields(input_and_output: fn() -> (syn::DeriveInput, syn::ItemImpl)) {
let (type_definition, expected_impl) = input_and_output();
let tokens = derive_visit_unknown_fields_impl(type_definition);
println!("{tokens}");
let output: syn::ItemImpl = syn::parse2(tokens).unwrap();
assert!(
output == expected_impl,
"got:\n{}\nwanted:\n{}",
output.to_token_stream(),
expected_impl.to_token_stream()
);
}
}

View File

@ -4,7 +4,6 @@
//
use libsignal_protocol::Aci;
use protobuf::MessageDyn;
#[derive(Debug, thiserror::Error)]
pub(crate) enum ParseHexError<const N: usize> {
@ -42,16 +41,16 @@ pub(crate) enum ParseVerbosity {
PrintPretty,
}
fn print_oneline(message: &dyn MessageDyn) {
fn print_oneline(message: &dyn std::fmt::Debug) {
eprintln!("{message:?}")
}
fn print_pretty(message: &dyn MessageDyn) {
fn print_pretty(message: &dyn std::fmt::Debug) {
eprintln!("{message:#?}")
}
impl ParseVerbosity {
pub(crate) fn into_visitor(self) -> Option<fn(&dyn MessageDyn)> {
pub(crate) fn into_visitor(self) -> Option<fn(&dyn std::fmt::Debug)> {
match self {
ParseVerbosity::None => None,
ParseVerbosity::PrintOneLine => Some(print_oneline),

View File

@ -6,7 +6,7 @@
//! Signal remote message backup utilities.
//!
use futures::{AsyncRead, AsyncSeek};
use protobuf::{Message as _, MessageDyn};
use protobuf::Message as _;
use crate::key::MessageBackupKey;
use crate::parse::VarintDelimitedReader;
@ -21,7 +21,7 @@ pub mod unknown;
pub struct BackupReader<R> {
reader: VarintDelimitedReader<R>,
pub visitor: fn(&dyn MessageDyn),
pub visitor: fn(&dyn std::fmt::Debug),
}
#[derive(Debug, thiserror::Error, displaydoc::Display)]
@ -107,7 +107,7 @@ impl<R: AsyncRead + AsyncSeek + Unpin> BackupReader<frame::FramesReader<R>> {
async fn read_all_frames<M: backup::method::Method>(
mut reader: VarintDelimitedReader<impl AsyncRead + Unpin>,
mut visitor: impl FnMut(&dyn MessageDyn),
mut visitor: impl FnMut(&dyn std::fmt::Debug),
unknown_fields: &mut impl Extend<FoundUnknownField>,
) -> Result<backup::PartialBackup<M>, Error> {
let mut add_found_unknown = |found_unknown: Vec<_>, index| {

View File

@ -7,10 +7,11 @@
use std::ops::{ControlFlow, Deref};
use protobuf::MessageFull;
#[cfg(test)]
mod visit_dyn;
pub(crate) mod visit_static;
/// Formatter for a sequence of [`PathPart`]s.
///
/// Provides a custom [`std::fmt::Display`] impl.
@ -57,7 +58,7 @@ pub enum UnknownValue {
///
/// Implemented as a singly linked list to avoid allocation.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Path<'a> {
pub(crate) enum Path<'a> {
Root,
Branch {
parent: &'a Path<'a>,
@ -68,7 +69,7 @@ enum Path<'a> {
/// The part of a logical field that is being referenced.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Part<'a> {
pub(crate) enum Part<'a> {
Field,
MapValue { key: MapKey<'a> },
Repeated { index: usize },
@ -76,7 +77,7 @@ enum Part<'a> {
/// Key in a protobuf map field.
#[derive(Copy, Clone, Debug, Eq, PartialEq, displaydoc::Display)]
enum MapKey<'a> {
pub(crate) enum MapKey<'a> {
/// {0}
U32(u32),
/// {0}
@ -91,6 +92,22 @@ enum MapKey<'a> {
String(&'a str),
}
// Implements `From<&'a $ty> for MapKey<'a>` by wrapping the dereferenced value
// in `MapKey::$constructor`. The optional third argument is a token placed in
// front of `*value`; the `String` invocation passes `&` so the variant
// receives a borrow of the string instead of a copy of the value.
macro_rules! impl_map_key_from {
($ty:ty, $constructor:ident $(, $maybe_borrow:tt)?) => {
impl<'a> From<&'a $ty> for MapKey<'a> {
fn from(value: &'a $ty) -> Self {
Self::$constructor($($maybe_borrow)? *value)
}
}
};
}
impl_map_key_from!(u32, U32);
impl_map_key_from!(u64, U64);
impl_map_key_from!(i32, I32);
impl_map_key_from!(i64, I64);
impl_map_key_from!(bool, Bool);
impl_map_key_from!(String, String, &);
impl Path<'_> {
fn owned_parts(&self) -> Vec<PathPart> {
let mut head = self;
@ -141,13 +158,10 @@ pub(crate) trait VisitUnknownFields {
pub trait UnknownFieldVisitor: FnMut(Vec<PathPart>, UnknownValue) -> ControlFlow<()> {}
impl<F: FnMut(Vec<PathPart>, UnknownValue) -> ControlFlow<()>> UnknownFieldVisitor for F {}
impl<M: MessageFull> VisitUnknownFields for M {
impl<M: visit_static::VisitUnknownFields> VisitUnknownFields for M {
fn visit_unknown_fields<F: UnknownFieldVisitor>(&self, mut visitor: F) {
// Currently implemented using dynamic traversal of protobuf
// descriptors.
// TODO: evaluate speed and code size versus statically-dispatched
// traversal.
let _: ControlFlow<()> = visit_dyn::visit_unknown_fields(self, Path::Root, &mut visitor);
let _: ControlFlow<()> =
visit_static::VisitUnknownFields::visit_unknown_fields(self, Path::Root, &mut visitor);
}
}
@ -187,8 +201,8 @@ impl<V: VisitUnknownFields> VisitUnknownFieldsExt for V {
mod test {
use std::collections::HashMap;
use protobuf::{Message, MessageFull};
use test_case::test_case;
use protobuf::Message;
use test_case::{test_case, test_matrix};
use super::*;
@ -234,12 +248,45 @@ mod test {
unreachable!("unexpectedly visited")
}
#[test_case(proto::TestMessage::default())]
#[test_case(proto::TestMessage::fake_data())]
#[test_case(proto::TestMessage::fake_data().wire_cast_as::<proto::TestMessage>())]
#[test_case(proto::TestMessage::fake_data().wire_cast_as::<proto::TestMessageWithExtraFields>())]
fn no_extra_fields(proto: impl MessageFull) {
proto.visit_unknown_fields(never_visits);
struct ViaProtoDescriptors<M>(M);
struct ViaStaticDispatch<M>(M);
impl VisitUnknownFields for ViaProtoDescriptors<proto::TestMessage> {
fn visit_unknown_fields<F: UnknownFieldVisitor>(&self, mut visitor: F) {
let _: ControlFlow<()> =
visit_dyn::visit_unknown_fields(&self.0, Path::Root, &mut visitor);
}
}
impl VisitUnknownFields for ViaProtoDescriptors<proto::TestMessageWithExtraFields> {
fn visit_unknown_fields<F: UnknownFieldVisitor>(&self, mut visitor: F) {
let _: ControlFlow<()> =
visit_dyn::visit_unknown_fields(&self.0, Path::Root, &mut visitor);
}
}
impl<M: visit_static::VisitUnknownFields> VisitUnknownFields for ViaStaticDispatch<M> {
fn visit_unknown_fields<F: UnknownFieldVisitor>(&self, mut visitor: F) {
let _: ControlFlow<()> = visit_static::VisitUnknownFields::visit_unknown_fields(
&self.0,
Path::Root,
&mut visitor,
);
}
}
#[test_matrix(
(
proto::TestMessage::default(),
proto::TestMessage::fake_data(),
proto::TestMessage::fake_data().wire_cast_as::<proto::TestMessage>(),
proto::TestMessage::fake_data().wire_cast_as::<proto::TestMessageWithExtraFields>(),
),
(ViaProtoDescriptors, ViaStaticDispatch)
)]
fn no_extra_fields<M, V: VisitUnknownFields>(proto: M, to_visitor: impl FnOnce(M) -> V) {
let visitor = to_visitor(proto);
visitor.visit_unknown_fields(never_visits);
}
macro_rules! modifier {
@ -288,36 +335,44 @@ mod test {
)])
);
#[test_case(oneof_extra_int, UnknownValue::Field {tag: 612})]
#[test_case(oneof_extra_string, UnknownValue::Field {tag: 611})]
#[test_case(oneof_extra_message, UnknownValue::Field {tag: 610})]
#[test_case(extra_string, UnknownValue::Field {tag: 701})]
#[test_case(extra_bytes, UnknownValue::Field {tag: 731})]
#[test_case(extra_int64, UnknownValue::Field {tag: 711})]
#[test_case(extra_repeated_message, UnknownValue::Field {tag: 721})]
#[test_case(extra_repeated_uint64, UnknownValue::Field {tag: 741})]
#[test_case(extra_enum, UnknownValue::Field {tag: 751})]
#[test_case(extra_nested, UnknownValue::Field {tag: 761})]
#[test_case(extra_map, UnknownValue::Field {tag: 771})]
fn has_unknown_fields_top_level(
modifier: fn(&mut proto::TestMessageWithExtraFields),
expected_value: UnknownValue,
#[test_matrix(
(
(oneof_extra_int, UnknownValue::Field {tag: 612}),
(oneof_extra_string, UnknownValue::Field {tag: 611}),
(oneof_extra_message, UnknownValue::Field {tag: 610}),
(extra_string, UnknownValue::Field {tag: 701}),
(extra_bytes, UnknownValue::Field {tag: 731}),
(extra_int64, UnknownValue::Field {tag: 711}),
(extra_repeated_message, UnknownValue::Field {tag: 721}),
(extra_repeated_uint64, UnknownValue::Field {tag: 741}),
(extra_enum, UnknownValue::Field {tag: 751}),
(extra_nested, UnknownValue::Field {tag: 761}),
(extra_map, UnknownValue::Field {tag: 771}),
),
(ViaProtoDescriptors, ViaStaticDispatch)
)]
fn has_unknown_fields_top_level<V: VisitUnknownFields>(
(modifier, expected_value): (fn(&mut proto::TestMessageWithExtraFields), UnknownValue),
to_visitor: impl FnOnce(proto::TestMessage) -> V,
) {
let mut message =
proto::TestMessage::fake_data().wire_cast_as::<proto::TestMessageWithExtraFields>();
modifier(&mut message);
let (path, value) = message
.wire_cast_as::<proto::TestMessage>()
.find_unknown_field()
.expect("has unknown");
let message = message.wire_cast_as::<proto::TestMessage>();
let visitor = to_visitor(message);
let (path, value) = visitor.find_unknown_field().expect("has unknown");
assert_eq!(value, expected_value);
assert_eq!(path, &[]);
}
#[test]
fn unknown_fields_in_nested_message() {
#[test_case(ViaProtoDescriptors)]
#[test_case(ViaStaticDispatch)]
fn unknown_fields_in_nested_message<V: VisitUnknownFields>(
to_visitor: impl FnOnce(proto::TestMessage) -> V,
) {
let message = proto::TestMessageWithExtraFields {
nested_message: Some(proto::TestMessageWithExtraFields {
repeated_message: vec![
@ -352,12 +407,12 @@ mod test {
UnknownValue::Field { tag: 711 },
),
(
"map[\"map_key\"].oneof_message.enum",
"map[\"map_key\"].oneof.oneof_message.enum",
UnknownValue::EnumValue { number: 3 },
),
];
let found: Vec<_> = message
let found: Vec<_> = to_visitor(message)
.collect_unknown_fields()
.into_iter()
.map(|(key, value)| {

View File

@ -28,6 +28,15 @@ pub(super) fn visit_unknown_fields(
}
for field in message.descriptor_dyn().fields() {
let containing_oneof = field.containing_oneof();
let path = match containing_oneof.as_ref() {
None => path,
Some(oneof) => Path::Branch {
parent: &path,
field_name: oneof.name(),
part: Part::Field,
},
};
visit_child_unknown_fields(field.get_reflect(message), path, field.name(), visitor)?
}

View File

@ -0,0 +1,173 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
use std::collections::HashMap;
use std::ops::ControlFlow;
use protobuf::{EnumOrUnknown, MessageField, SpecialFields, UnknownFields};
pub(crate) use libsignal_message_backup_macros::VisitUnknownFields;
use crate::unknown::{MapKey, Part, Path, UnknownFieldVisitor, UnknownValue};
/// Receives what the static traversal finds.
///
/// Both methods return a [`ControlFlow`]; callers in this module propagate it
/// with `?`, so a `Break` ends the traversal early.
pub(crate) trait Visitor {
/// Called with the unknown-field set stored at `path`.
fn unknown_fields(&mut self, path: Path<'_>, unknown: &UnknownFields) -> ControlFlow<()>;
/// Called with the numeric value of an enum field at `path` that does not
/// correspond to a known enum value.
fn unknown_enum(&mut self, path: Path<'_>, value: i32) -> ControlFlow<()>;
}
/// Adapts any [`UnknownFieldVisitor`] callback to the [`Visitor`] interface.
impl<U: UnknownFieldVisitor> Visitor for U {
    /// Reports each entry of `unknown` as an `UnknownValue::Field` carrying
    /// its tag, stopping as soon as the callback breaks.
    fn unknown_fields(&mut self, path: Path<'_>, unknown: &UnknownFields) -> ControlFlow<()> {
        unknown.into_iter().try_for_each(|(tag, _value)| {
            self(path.owned_parts(), UnknownValue::Field { tag })
        })
    }

    /// Reports `value` as an `UnknownValue::EnumValue`.
    fn unknown_enum(&mut self, path: Path<'_>, value: i32) -> ControlFlow<()> {
        let parts = path.owned_parts();
        let found = UnknownValue::EnumValue { number: value };
        self(parts, found)
    }
}
/// Statically dispatched traversal of a value in search of unknown fields.
pub(crate) trait VisitUnknownFields {
/// Calls the visitor for each unknown field in the message.
///
/// `path` is the location of `self` within the enclosing message. The
/// returned [`ControlFlow`] is the one produced by the visitor.
fn visit_unknown_fields(&self, path: Path<'_>, visitor: &mut impl Visitor) -> ControlFlow<()>;
}
/// A reference is visited by visiting the value it points to.
impl<V: VisitUnknownFields> VisitUnknownFields for &V {
    fn visit_unknown_fields(&self, path: Path<'_>, visitor: &mut impl Visitor) -> ControlFlow<()> {
        let target: &V = *self;
        <V as VisitUnknownFields>::visit_unknown_fields(target, path, visitor)
    }
}
/// An enum field is reported only when its stored number is not a known
/// value of `E`.
impl<E: protobuf::Enum> VisitUnknownFields for EnumOrUnknown<E> {
    fn visit_unknown_fields(&self, path: Path<'_>, visitor: &mut impl Visitor) -> ControlFlow<()> {
        if let Err(unrecognized) = self.enum_value() {
            visitor.unknown_enum(path, unrecognized)
        } else {
            ControlFlow::Continue(())
        }
    }
}
/// A boxed value is visited by visiting its contents.
impl<E: VisitUnknownFields> VisitUnknownFields for Box<E> {
    fn visit_unknown_fields(&self, path: Path<'_>, visitor: &mut impl Visitor) -> ControlFlow<()> {
        let contents: &E = &**self;
        <E as VisitUnknownFields>::visit_unknown_fields(contents, path, visitor)
    }
}
/// An unset message field has nothing to report; a set one is visited at the
/// same `path`.
impl<E: VisitUnknownFields> VisitUnknownFields for MessageField<E> {
    fn visit_unknown_fields(&self, path: Path<'_>, visitor: &mut impl Visitor) -> ControlFlow<()> {
        match &self.0 {
            Some(inner) => inner.visit_unknown_fields(path, visitor),
            None => ControlFlow::Continue(()),
        }
    }
}
/// Traversal for fields whose type wraps zero or more values.
///
/// Unlike [`VisitUnknownFields`], the implementor receives the path of the
/// enclosing message plus the field's name, and constructs the path of each
/// contained value itself.
pub(crate) trait VisitContainerUnknownFields {
fn visit_unknown_fields_within(
&self,
parent_path: Path<'_>,
field_name: &str,
visitor: &mut impl Visitor,
) -> ControlFlow<()>;
}
/// Reports the unknown fields held by a message's `SpecialFields`.
impl VisitContainerUnknownFields for SpecialFields {
fn visit_unknown_fields_within(
&self,
parent_path: Path<'_>,
_field_name: &str,
visitor: &mut impl Visitor,
) -> ControlFlow<()> {
// The unknown fields are reported at the enclosing message's path; the
// field name is only checked in debug builds and never added to the path.
debug_assert_eq!(_field_name, "special_fields");
visitor.unknown_fields(parent_path, self.unknown_fields())
}
}
/// Visits every element of a repeated field, addressing each one by its
/// index, and stops at the first `Break`.
impl<U: VisitUnknownFields> VisitContainerUnknownFields for Vec<U> {
    fn visit_unknown_fields_within(
        &self,
        parent_path: Path<'_>,
        field_name: &str,
        visitor: &mut impl Visitor,
    ) -> ControlFlow<()> {
        self.iter().enumerate().try_for_each(|(index, element)| {
            let element_path = Path::Branch {
                parent: &parent_path,
                field_name,
                part: Part::Repeated { index },
            };
            element.visit_unknown_fields(element_path, visitor)
        })
    }
}
/// Visits every value of a map field, addressing each one by its key, in the
/// map's iteration order, and stops at the first `Break`.
impl<K, V: VisitUnknownFields> VisitContainerUnknownFields for HashMap<K, V>
where
    for<'a> &'a K: Into<MapKey<'a>>,
{
    fn visit_unknown_fields_within(
        &self,
        parent_path: Path<'_>,
        field_name: &str,
        visitor: &mut impl Visitor,
    ) -> ControlFlow<()> {
        self.iter().try_for_each(|(key, entry)| {
            let entry_path = Path::Branch {
                parent: &parent_path,
                field_name,
                part: Part::MapValue { key: key.into() },
            };
            entry.visit_unknown_fields(entry_path, visitor)
        })
    }
}
/// An absent optional field has nothing to report; a present one is visited
/// as an ordinary field named `field_name` under `parent_path`.
impl<U: VisitUnknownFields> VisitContainerUnknownFields for Option<U> {
    fn visit_unknown_fields_within(
        &self,
        parent_path: Path<'_>,
        field_name: &str,
        visitor: &mut impl Visitor,
    ) -> ControlFlow<()> {
        match self {
            None => ControlFlow::Continue(()),
            Some(inner) => {
                let inner_path = Path::Branch {
                    parent: &parent_path,
                    field_name,
                    part: Part::Field,
                };
                inner.visit_unknown_fields(inner_path, visitor)
            }
        }
    }
}
// Implements `VisitUnknownFields` for a type that cannot hold unknown fields:
// the generated method ignores both arguments and always continues.
macro_rules! no_unknown_fields {
($type:path) => {
impl VisitUnknownFields for $type {
fn visit_unknown_fields(
&self,
_path: Path<'_>,
_visitor: &mut impl Visitor,
) -> ControlFlow<()> {
ControlFlow::Continue(())
}
}
};
}
// Leaf value types. `Vec<u8>` gets its own impl in addition to `u8`, which
// lets a `Vec<Vec<u8>>` be visited through the container impl for `Vec<U>`.
no_unknown_fields!(u8);
no_unknown_fields!(u32);
no_unknown_fields!(u64);
no_unknown_fields!(i32);
no_unknown_fields!(i64);
no_unknown_fields!(bool);
no_unknown_fields!(String);
no_unknown_fields!(Vec<u8>);