0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-20 03:52:16 +02:00

Add spam report token support to ReportedMessageListener

This commit is contained in:
Jon Chambers 2023-01-30 11:33:39 -05:00 committed by Jon Chambers
parent 00e08b8402
commit 4a2768b81d
6 changed files with 53 additions and 23 deletions

View File

@ -643,13 +643,14 @@ public class MessageController {
UUID spamReporterUuid = auth.getAccount().getUuid();
// spam report token is optional, but if provided ensure it is valid base64.
@Nullable final byte[] spamReportToken = spamReport != null ? spamReport.token() : null;
final Optional<byte[]> maybeSpamReportToken =
spamReport != null ? Optional.of(spamReport.token()) : Optional.empty();
// fire-and-forget: we don't want to block the response on this action.
CompletableFuture<Boolean> ignored =
reportSpamTokenHandler.handle(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, spamReportToken);
reportSpamTokenHandler.handle(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, maybeSpamReportToken.orElse(null));
reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid);
reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, maybeSpamReportToken);
return Response.status(Status.ACCEPTED)
.build();

View File

@ -9,6 +9,7 @@ import static com.codahale.metrics.MetricRegistry.name;
import io.micrometer.core.instrument.Metrics;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import net.logstash.logback.marker.Markers;
import org.slf4j.Logger;
@ -35,7 +36,9 @@ public class ReportedMessageMetricsListener implements ReportedMessageListener {
}
@Override
public void handleMessageReported(final String sourceNumber, final UUID messageGuid, final UUID reporterUuid) {
public void handleMessageReported(final String sourceNumber, final UUID messageGuid, final UUID reporterUuid,
final Optional<byte[]> reportSpamToken) {
final String sourceCountryCode = Util.getCountryCode(sourceNumber);
Metrics.counter(REPORTED_COUNTER_NAME, COUNTRY_CODE_TAG_NAME, sourceCountryCode).increment();

View File

@ -15,8 +15,10 @@ import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
@ -29,6 +31,10 @@ public class ReportMessageManager {
private final List<ReportedMessageListener> reportedMessageListeners = new ArrayList<>();
private static final String REPORT_MESSAGE_COUNTER_NAME = MetricsUtil.name(ReportMessageManager.class);
private static final String FOUND_MESSAGE_TAG = "foundMessage";
private static final String TOKEN_PRESENT_TAG = "hasReportSpamToken";
private static final Logger logger = LoggerFactory.getLogger(ReportMessageManager.class);
public ReportMessageManager(final ReportMessageDynamoDb reportMessageDynamoDb,
@ -56,12 +62,21 @@ public class ReportMessageManager {
}
}
public void report(Optional<String> sourceNumber, Optional<UUID> sourceAci, Optional<UUID> sourcePni,
UUID messageGuid, UUID reporterUuid) {
public void report(final Optional<String> sourceNumber,
final Optional<UUID> sourceAci,
final Optional<UUID> sourcePni,
final UUID messageGuid,
final UUID reporterUuid,
final Optional<byte[]> reportSpamToken) {
final boolean found = sourceAci.map(uuid -> reportMessageDynamoDb.remove(hash(messageGuid, uuid.toString())))
.orElse(false);
Metrics.counter(REPORT_MESSAGE_COUNTER_NAME,
FOUND_MESSAGE_TAG, String.valueOf(found),
TOKEN_PRESENT_TAG, String.valueOf(reportSpamToken.isPresent()))
.increment();
if (found) {
rateLimitCluster.useCluster(connection -> {
sourcePni.ifPresent(pni -> {
@ -80,7 +95,7 @@ public class ReportMessageManager {
sourceNumber.ifPresent(number ->
reportedMessageListeners.forEach(listener -> {
try {
listener.handleMessageReported(number, messageGuid, reporterUuid);
listener.handleMessageReported(number, messageGuid, reporterUuid, reportSpamToken);
} catch (final Exception e) {
logger.error("Failed to notify listener of reported message", e);
}

View File

@ -5,9 +5,10 @@
package org.whispersystems.textsecuregcm.storage;
import java.util.Optional;
import java.util.UUID;
public interface ReportedMessageListener {
void handleMessageReported(String sourceNumber, UUID messageGuid, UUID reporterUuid);
void handleMessageReported(String sourceNumber, UUID messageGuid, UUID reporterUuid, Optional<byte[]> reportSpamToken);
}

View File

@ -14,6 +14,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyString;
@ -67,6 +68,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenHandler;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@ -634,7 +636,7 @@ class MessageControllerTest {
assertThat(response.getStatus(), is(equalTo(202)));
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
messageGuid, AuthHelper.VALID_UUID);
messageGuid, AuthHelper.VALID_UUID, Optional.empty());
verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class));
verify(accountsManager, never()).getPhoneNumberIdentifier(anyString());
@ -651,7 +653,7 @@ class MessageControllerTest {
assertThat(response.getStatus(), is(equalTo(202)));
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
messageGuid, AuthHelper.VALID_UUID);
messageGuid, AuthHelper.VALID_UUID, Optional.empty());
}
@Test
@ -681,7 +683,7 @@ class MessageControllerTest {
assertThat(response.getStatus(), is(equalTo(202)));
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
messageGuid, AuthHelper.VALID_UUID);
messageGuid, AuthHelper.VALID_UUID, Optional.empty());
verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class));
verify(accountsManager, never()).getPhoneNumberIdentifier(anyString());
@ -699,7 +701,7 @@ class MessageControllerTest {
assertThat(response.getStatus(), is(equalTo(202)));
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
messageGuid, AuthHelper.VALID_UUID);
messageGuid, AuthHelper.VALID_UUID, Optional.empty());
}
@Test
@ -733,8 +735,12 @@ class MessageControllerTest {
assertThat(response.getStatus(), is(equalTo(202)));
verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture());
assertArrayEquals(new byte[3], captor.getValue());
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
messageGuid, AuthHelper.VALID_UUID);
verify(reportMessageManager).report(eq(Optional.of(senderNumber)),
eq(Optional.of(senderAci)),
eq(Optional.of(senderPni)),
eq(messageGuid),
eq(AuthHelper.VALID_UUID),
argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[3])).orElse(false)));
verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class));
verify(accountsManager, never()).getPhoneNumberIdentifier(anyString());
when(accountsManager.getByAccountIdentifier(senderAci)).thenReturn(Optional.empty());
@ -754,8 +760,12 @@ class MessageControllerTest {
assertThat(response.getStatus(), is(equalTo(202)));
verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture());
assertArrayEquals(new byte[5], captor.getValue());
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
messageGuid, AuthHelper.VALID_UUID);
verify(reportMessageManager).report(eq(Optional.of(senderNumber)),
eq(Optional.of(senderAci)),
eq(Optional.of(senderPni)),
eq(messageGuid),
eq(AuthHelper.VALID_UUID),
argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[5])).orElse(false)));
}
@Test

View File

@ -81,16 +81,16 @@ class ReportMessageManagerTest {
when(reportMessageDynamoDb.remove(any())).thenReturn(false);
reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid,
reporterUuid);
reporterUuid, Optional.empty());
assertEquals(0, reportMessageManager.getRecentReportCount(sourceAccount));
when(reportMessageDynamoDb.remove(any())).thenReturn(true);
reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid,
reporterUuid);
reporterUuid, Optional.empty());
assertEquals(1, reportMessageManager.getRecentReportCount(sourceAccount));
verify(listener).handleMessageReported(sourceNumber, messageGuid, reporterUuid);
verify(listener).handleMessageReported(sourceNumber, messageGuid, reporterUuid, Optional.empty());
}
@Test
@ -100,7 +100,7 @@ class ReportMessageManagerTest {
for (int i = 0; i < 100; i++) {
reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni),
messageGuid, UUID.randomUUID());
messageGuid, UUID.randomUUID(), Optional.empty());
}
assertTrue(reportMessageManager.getRecentReportCount(sourceAccount) > 10);
@ -114,7 +114,7 @@ class ReportMessageManagerTest {
for (int i = 0; i < 100; i++) {
reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni),
messageGuid,
reporterUuid);
reporterUuid, Optional.empty());
}
assertEquals(1, reportMessageManager.getRecentReportCount(sourceAccount));
@ -127,11 +127,11 @@ class ReportMessageManagerTest {
for (int i = 0; i < 100; i++) {
reportMessageManager.report(Optional.empty(), Optional.of(sourceAci), Optional.of(sourcePni),
messageGuid, UUID.randomUUID());
messageGuid, UUID.randomUUID(), Optional.empty());
}
reportMessageManager.report(Optional.empty(), Optional.of(sourceAci), Optional.empty(),
messageGuid, UUID.randomUUID());
messageGuid, UUID.randomUUID(), Optional.empty());
final int recentReportCount = reportMessageManager.getRecentReportCount(sourceAccount);
assertTrue(recentReportCount > 10);