0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-20 03:52:17 +02:00

Implement protobuf unknown field detection using field descriptors

Implement protobuf unknown field search by walking the tree of field
descriptors.
This commit is contained in:
Alex Konradi 2024-01-10 17:08:13 -05:00 committed by GitHub
parent 897051d97c
commit 3afe5bfe58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 577 additions and 4 deletions

View File

@ -45,4 +45,5 @@ test-case = "3.3.1"
test-log = "0.2.14"
[build-dependencies]
protobuf = "3.3.0"
protobuf-codegen = "3.3.0"

View File

@ -4,7 +4,9 @@
//
fn main() {
let protos = ["src/proto/backup.proto"];
const PROTOS: &[&str] = &["src/proto/backup.proto", "src/proto/test.proto"];
const PROTOS_DIR: &str = "protos";
protobuf_codegen::Codegen::new()
.protoc()
.protoc_extra_arg(
@ -14,10 +16,19 @@ fn main() {
"--experimental_allow_proto3_optional",
)
.include("src")
.inputs(protos)
.cargo_out_dir("protos")
.inputs(PROTOS)
.cargo_out_dir(PROTOS_DIR)
.run_from_script();
for proto in &protos {
// Mark the test.proto module as test-only.
let out_mod_rs = format!("{}/{PROTOS_DIR}/mod.rs", std::env::var("OUT_DIR").unwrap());
let mut contents = std::fs::read_to_string(&out_mod_rs).unwrap();
let insert_pos = contents.find("pub mod test;").unwrap_or(0);
contents.insert_str(insert_pos, "\n#[cfg(test)] // only for testing\n");
std::fs::write(out_mod_rs, contents).unwrap();
for proto in PROTOS {
println!("cargo:rerun-if-changed={}", proto);
}
}

View File

@ -3,6 +3,8 @@
// SPDX-License-Identifier: AGPL-3.0-only
//
pub(crate) mod unknown;
include!(concat!(env!("OUT_DIR"), "/protos/mod.rs"));
macro_rules! impl_into_oneof {

View File

@ -0,0 +1,80 @@
syntax = "proto3";
package signal.backup.test;
option java_package = "org.thoughtcrime.securesms.backup.v2.proto.test";
enum TestEnum {
ZERO = 0;
ONE = 1;
TWO = 2;
}
enum TestEnumWithExtraVariants {
ZERO_EXTRA_VARIANTS = 0;
ONE_EXTRA_VARIANTS = 1;
TWO_EXTRA_VARIANTS = 2;
EXTRA_THREE = 3;
EXTRA_FOUR = 4;
}
message TestMessage {
oneof oneof {
bool oneof_bool = 600;
TestMessage oneof_message = 601;
}
string string = 700;
int64 int64 = 710;
repeated TestMessage repeated_message = 720;
bytes bytes = 730;
repeated uint64 repeated_uint64 = 740;
TestEnum enum = 750;
optional TestMessage nested_message = 760;
map<string, TestMessage> map = 770;
}
message TestMessageWithExtraFields {
oneof oneof {
bool oneof_bool = 600;
TestMessageWithExtraFields oneof_message = 601;
TestMessageWithExtraFields oneof_extra_message = 610;
string oneof_extra_string = 611;
int64 oneof_extra_int64 = 612;
}
string string = 700;
string extra_string = 701;
int64 int64 = 710;
int64 extra_int64 = 711;
repeated TestMessageWithExtraFields repeated_message = 720;
repeated TestMessageWithExtraFields extra_repeated_message = 721;
bytes bytes = 730;
bytes extra_bytes = 731;
repeated uint64 repeated_uint64 = 740;
repeated uint64 extra_repeated_uint64 = 741;
// Intentionally use the wrong enum type for the same field.
TestEnumWithExtraVariants enum = 750;
TestEnumWithExtraVariants extra_enum = 751;
optional TestMessageWithExtraFields nested_message = 760;
optional TestMessageWithExtraFields extra_nested_message = 761;
map<string, TestMessageWithExtraFields> map = 770;
map<string, TestMessageWithExtraFields> extra_map = 771;
}

View File

@ -0,0 +1,354 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
//! Protobuf unknown field searching.
use std::ops::ControlFlow;
use protobuf::MessageFull;
mod visit_dyn;
/// Protobuf message path component.
#[derive(Clone, Debug, Eq, PartialEq, displaydoc::Display)]
pub enum PathPart {
/// {field_name}[{index}]
Repeated { field_name: String, index: usize },
/// {field_name}
Field { field_name: String },
/// {field_name}[{key}]
MapValue { field_name: String, key: String },
}
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum UnknownValue {
EnumValue { number: i32 },
Field { tag: u32 },
}
/// A path within a protobuf message.
///
/// Implemented as a singly linked list to avoid allocation.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Path<'a> {
Root,
Branch {
parent: &'a Path<'a>,
field_name: &'a str,
part: Part<'a>,
},
}
/// The part of a logical field that is being referenced.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Part<'a> {
Field,
MapValue { key: MapKey<'a> },
Repeated { index: usize },
}
/// Key in a protobuf map field.
#[derive(Copy, Clone, Debug, Eq, PartialEq, displaydoc::Display)]
enum MapKey<'a> {
/// {0}
U32(u32),
/// {0}
U64(u64),
/// {0}
I32(i32),
/// {0}
I64(i64),
/// {0}
Bool(bool),
/// {0:?}
String(&'a str),
}
impl Path<'_> {
fn owned_parts(&self) -> Vec<PathPart> {
let mut head = self;
let mut output = Vec::new();
// Standard linked-list traversal.
while let Path::Branch {
parent,
part,
field_name,
} = head
{
let field_name = field_name.to_string();
output.push(PathPart::from_part(part, field_name));
head = parent;
}
// Since each `Path` points at its parent, the list is lowest-to-highest
// path part before reversal.
output.reverse();
output
}
}
impl PathPart {
fn from_part(part: &Part<'_>, field_name: String) -> Self {
match part {
Part::Field => Self::Field { field_name },
Part::MapValue { key } => Self::MapValue {
field_name,
key: key.to_string(),
},
Part::Repeated { index } => Self::Repeated {
field_name,
index: *index,
},
}
}
}
/// Visitor for unknown fields on a [`protobuf::Message`].
pub(super) trait VisitUnknownFields {
/// Calls the visitor for each unknown field in the message.
fn visit_unknown_fields<F: UnknownFieldVisitor>(&self, visitor: F);
}
/// Convenience "alias" for a callable visitor.
pub trait UnknownFieldVisitor: FnMut(Vec<PathPart>, UnknownValue) -> ControlFlow<()> {}
impl<F: FnMut(Vec<PathPart>, UnknownValue) -> ControlFlow<()>> UnknownFieldVisitor for F {}
impl<M: MessageFull> VisitUnknownFields for M {
fn visit_unknown_fields<F: UnknownFieldVisitor>(&self, mut visitor: F) {
// Currently implemented using dynamic traversal of protobuf
// descriptors.
// TODO: evaulate speed and code size versus statically-dispatched
// traversal.
let _: ControlFlow<()> = visit_dyn::visit_unknown_fields(self, Path::Root, &mut visitor);
}
}
/// Extension trait for [`VisitUnknownFields`] with convenience methods.
pub(super) trait VisitUnknownFieldsExt {
fn has_unknown_fields(&self) -> bool;
fn collect_unknown_fields(&self) -> Vec<(Vec<PathPart>, UnknownValue)>;
fn find_unknown_field(&self) -> Option<(Vec<PathPart>, UnknownValue)>;
}
impl<V: VisitUnknownFields> VisitUnknownFieldsExt for V {
fn has_unknown_fields(&self) -> bool {
self.find_unknown_field().is_some()
}
fn collect_unknown_fields(&self) -> Vec<(Vec<PathPart>, UnknownValue)> {
let mut found = Vec::new();
self.visit_unknown_fields(|path, value| {
found.push((path, value));
ControlFlow::Continue(())
});
found
}
fn find_unknown_field(&self) -> Option<(Vec<PathPart>, UnknownValue)> {
let mut found = None;
self.visit_unknown_fields(|path, value| {
found = Some((path, value));
ControlFlow::Break(())
});
found
}
}
#[cfg(test)]
mod test {
use std::collections::HashMap;
use protobuf::{Message, MessageFull};
use test_case::test_case;
use super::*;
use crate::proto::test as proto;
trait ProtoWireCast {
fn wire_cast_as<M: Message>(self) -> M;
}
impl<S: Message> ProtoWireCast for S {
fn wire_cast_as<M: Message>(self) -> M {
let mut bytes = Vec::new();
self.write_to_vec(&mut bytes).expect("can serialize");
M::parse_from_bytes(&bytes).expect("can deserialize")
}
}
const FAKE_BYTES: [u8; 5] = *b"abcde";
const FAKE_STRING: &str = "fghij";
const FAKE_INT64: i64 = 49582945;
const FAKE_REPEATED_UINT64: [u64; 2] = [42, 85];
const FAKE_ONEOF: proto::test_message::Oneof = proto::test_message::Oneof::OneofBool(false);
const FAKE_ENUM: proto::TestEnum = proto::TestEnum::TWO;
impl proto::TestMessage {
fn fake_data() -> Self {
Self {
bytes: FAKE_BYTES.into(),
string: FAKE_STRING.into(),
int64: FAKE_INT64,
repeated_message: vec![proto::TestMessage::default(); 3],
repeated_uint64: FAKE_REPEATED_UINT64.into(),
enum_: FAKE_ENUM.into(),
oneof: Some(FAKE_ONEOF),
nested_message: Some(proto::TestMessage::default()).into(),
map: HashMap::from([("key".to_string(), proto::TestMessage::default())]),
special_fields: Default::default(),
}
}
}
fn never_visits(_: Vec<PathPart>, _: UnknownValue) -> ControlFlow<()> {
unreachable!("unexpectedly visited")
}
#[test_case(proto::TestMessage::default())]
#[test_case(proto::TestMessage::fake_data())]
#[test_case(proto::TestMessage::fake_data().wire_cast_as::<proto::TestMessage>())]
#[test_case(proto::TestMessage::fake_data().wire_cast_as::<proto::TestMessageWithExtraFields>())]
fn no_extra_fields(proto: impl MessageFull) {
let m = never_visits;
proto.visit_unknown_fields(m);
}
macro_rules! modifier {
($name:ident, $field:ident = $value:expr) => {
fn $name(target: &mut proto::TestMessageWithExtraFields) {
#[allow(unused)]
use proto::test_message_with_extra_fields::Oneof;
target.$field = $value;
}
};
}
modifier!(oneof_extra_int, oneof = Some(Oneof::OneofExtraInt64(32)));
modifier!(
oneof_extra_string,
oneof = Some(Oneof::OneofExtraString("asdf".into()))
);
modifier!(
oneof_extra_message,
oneof = Some(Oneof::OneofExtraMessage(Box::default()))
);
modifier!(extra_string, extra_string = FAKE_STRING.into());
modifier!(extra_bytes, extra_bytes = FAKE_BYTES.into());
modifier!(extra_int64, extra_int64 = FAKE_INT64);
modifier!(
extra_repeated_message,
extra_repeated_message = vec![proto::TestMessageWithExtraFields::default(); 4]
);
modifier!(
extra_repeated_uint64,
extra_repeated_uint64 = FAKE_REPEATED_UINT64.into()
);
modifier!(
extra_enum,
extra_enum = proto::TestEnumWithExtraVariants::TWO_EXTRA_VARIANTS.into()
);
modifier!(
extra_nested,
extra_nested_message = Some(proto::TestMessageWithExtraFields::default()).into()
);
modifier!(
extra_map,
extra_map = HashMap::from([(
"extra key".to_string(),
proto::TestMessageWithExtraFields::default()
)])
);
#[test_case(oneof_extra_int, UnknownValue::Field {tag: 612})]
#[test_case(oneof_extra_string, UnknownValue::Field {tag: 611})]
#[test_case(oneof_extra_message, UnknownValue::Field {tag: 610})]
#[test_case(extra_string, UnknownValue::Field {tag: 701})]
#[test_case(extra_bytes, UnknownValue::Field {tag: 731})]
#[test_case(extra_int64, UnknownValue::Field {tag: 711})]
#[test_case(extra_repeated_message, UnknownValue::Field {tag: 721})]
#[test_case(extra_repeated_uint64, UnknownValue::Field {tag: 741})]
#[test_case(extra_enum, UnknownValue::Field {tag: 751})]
#[test_case(extra_nested, UnknownValue::Field {tag: 761})]
#[test_case(extra_map, UnknownValue::Field {tag: 771})]
fn has_unknown_fields_top_level(
modifier: fn(&mut proto::TestMessageWithExtraFields),
expected_value: UnknownValue,
) {
let mut message =
proto::TestMessage::fake_data().wire_cast_as::<proto::TestMessageWithExtraFields>();
modifier(&mut message);
let (path, value) = message
.wire_cast_as::<proto::TestMessage>()
.find_unknown_field()
.expect("has unknown");
assert_eq!(value, expected_value);
assert_eq!(path, &[]);
}
#[test]
fn unknown_fields_in_nested_message() {
let message = proto::TestMessageWithExtraFields {
nested_message: Some(proto::TestMessageWithExtraFields {
repeated_message: vec![
proto::TestMessageWithExtraFields::default(),
proto::TestMessageWithExtraFields {
extra_int64: FAKE_INT64,
..Default::default()
},
],
..Default::default()
})
.into(),
map: HashMap::from([(
"map_key".to_string(),
proto::TestMessageWithExtraFields {
oneof: Some(proto::test_message_with_extra_fields::Oneof::OneofMessage(
Box::new(proto::TestMessageWithExtraFields {
enum_: proto::TestEnumWithExtraVariants::EXTRA_THREE.into(),
..Default::default()
}),
)),
..Default::default()
},
)]),
..Default::default()
};
let message: proto::TestMessage = message.wire_cast_as();
const EXPECTED_UNKNOWN: [(&str, UnknownValue); 2] = [
(
"nested_message.repeated_message[1]",
UnknownValue::Field { tag: 711 },
),
(
"map[\"map_key\"].oneof_message.enum",
UnknownValue::EnumValue { number: 3 },
),
];
let found: Vec<_> = message
.collect_unknown_fields()
.into_iter()
.map(|(key, value)| {
let key = key
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(".");
(key, value)
})
.collect();
assert_eq!(
HashMap::from_iter(found.iter().map(|(k, v)| (k.as_str(), *v))),
HashMap::from(EXPECTED_UNKNOWN)
);
}
}

View File

@ -0,0 +1,125 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
//! Unknown field searching via dynamic traversal of protubuf message
//! descriptors.
use std::ops::ControlFlow;
use protobuf::reflect::{ReflectFieldRef, ReflectValueRef};
use protobuf::MessageDyn;
use crate::proto::unknown::{MapKey, Part, Path, UnknownFieldVisitor, UnknownValue};
pub(super) fn visit_unknown_fields(
message: &dyn MessageDyn,
path: Path<'_>,
// It seems like this could be just `impl UnknownFieldVisitor`, but
// because this function is indirectly recursive and fans out, making it
// an owned type (and recursing with `&mut visitor`) causes the compiler
// to fail to determine whether `&mut &mut .... &mut impl
// UnknownFieldVisitor` implements `UnknownFieldVisitor`.
visitor: &mut impl UnknownFieldVisitor,
) -> ControlFlow<()> {
for (tag, _value) in message.unknown_fields_dyn() {
visitor(path.owned_parts(), UnknownValue::Field { tag })?
}
for field in message.descriptor_dyn().fields() {
visit_child_unknown_fields(field.get_reflect(message), path, field.name(), visitor)?
}
ControlFlow::Continue(())
}
fn visit_child_unknown_fields<'s>(
field: ReflectFieldRef<'s>,
parent_path: Path<'s>,
field_name: &'s str,
visitor: &mut impl UnknownFieldVisitor,
) -> ControlFlow<()> {
let make_path = |part| Path::Branch {
parent: &parent_path,
field_name,
part,
};
match field {
ReflectFieldRef::Optional(value) => {
let Some(value) = value.value() else {
return ControlFlow::Continue(());
};
visit_field(value, visitor, make_path(Part::Field))
}
ReflectFieldRef::Repeated(values) => {
for (index, value) in values.into_iter().enumerate() {
visit_field(value, visitor, make_path(Part::Repeated { index }))?;
}
ControlFlow::Continue(())
}
ReflectFieldRef::Map(values) => {
for (key, value) in &values {
let key = key.into();
visit_field(value, visitor, make_path(Part::MapValue { key }))?;
}
ControlFlow::Continue(())
}
}
}
fn visit_field<'s>(
value: ReflectValueRef<'s>,
visitor: &mut impl UnknownFieldVisitor,
path: Path<'s>,
) -> ControlFlow<()> {
match value {
ReflectValueRef::U32(_)
| ReflectValueRef::U64(_)
| ReflectValueRef::I32(_)
| ReflectValueRef::I64(_)
| ReflectValueRef::F32(_)
| ReflectValueRef::F64(_)
| ReflectValueRef::Bool(_)
| ReflectValueRef::String(_)
| ReflectValueRef::Bytes(_) => ControlFlow::Continue(()),
ReflectValueRef::Enum(descriptor, number) => {
if descriptor.value_by_number(number).is_none() {
visitor(path.owned_parts(), UnknownValue::EnumValue { number })
} else {
ControlFlow::Continue(())
}
}
ReflectValueRef::Message(message) => {
let message: &dyn MessageDyn = &*message;
visit_unknown_fields(message, path, visitor)
}
}
}
impl<'a> From<ReflectValueRef<'a>> for MapKey<'a> {
fn from(value: ReflectValueRef<'a>) -> Self {
match value {
ReflectValueRef::U32(v) => Self::U32(v),
ReflectValueRef::U64(v) => Self::U64(v),
ReflectValueRef::I32(v) => Self::I32(v),
ReflectValueRef::I64(v) => Self::I64(v),
ReflectValueRef::Bool(v) => Self::Bool(v),
ReflectValueRef::String(v) => Self::String(v),
v @ ReflectValueRef::F32(_)
| v @ ReflectValueRef::F64(_)
| v @ ReflectValueRef::Bytes(_)
| v @ ReflectValueRef::Enum(_, _)
| v @ ReflectValueRef::Message(_) => {
// Per the protobuf docs:
// > key_type can be any integral or string type (so, any scalar
// > type except for floating point types and bytes). Note that
// > enum is not a valid key_type."
unreachable!("unexpected key {v}")
}
}
}
}