diff --git a/.gitignore b/.gitignore index 6e2d65f2..721bacd1 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ deployer.log .project .classpath .settings +.DS_Store \ No newline at end of file diff --git a/service/config/sample.yml b/service/config/sample.yml index 84e91a74..e7a32878 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -450,3 +450,27 @@ turn: linkDevice: secret: secret://linkDevice.secret + +maxmindCityDatabase: + s3Region: a-region + s3Bucket: a-bucket + objectKey: an-object.tar.gz + maxSize: 32777216 + +callingTurnDnsRecords: + s3Region: a-region + s3Bucket: a-bucket + objectKey: an-object.tar.gz + maxSize: 32777216 + +callingTurnPerformanceTable: + s3Region: a-region + s3Bucket: a-bucket + objectKey: an-object.tar.gz + maxSize: 32777216 + +callingTurnManualTable: + s3Region: a-region + s3Bucket: a-bucket + objectKey: an-object.tar.gz + maxSize: 32777216 diff --git a/service/pom.xml b/service/pom.xml index b4584641..f971a6d5 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -204,6 +204,17 @@ org.apache.commons commons-csv + + org.apache.commons + commons-compress + 1.24.0 + + + + com.maxmind.geoip2 + geoip2 + 4.2.0 + com.google.firebase diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index dbca46e6..3c60e922 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -38,6 +38,7 @@ import org.whispersystems.textsecuregcm.configuration.LinkDeviceSecretConfigurat import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration; +import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration; import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.RecaptchaConfiguration; @@ -322,11 +323,31 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private VirtualThreadConfiguration virtualThreadConfiguration = new VirtualThreadConfiguration(Duration.ofMillis(1)); + + @Valid + @NotNull + @JsonProperty + private MonitoredS3ObjectConfiguration maxmindCityDatabase; + + @Valid + @NotNull + @JsonProperty + private MonitoredS3ObjectConfiguration callingTurnDnsRecords; + + @Valid + @NotNull + @JsonProperty + private MonitoredS3ObjectConfiguration callingTurnPerformanceTable; + + @Valid + @NotNull + @JsonProperty + private MonitoredS3ObjectConfiguration callingTurnManualTable; + public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() { return tlsKeyStore; } - public StripeConfiguration getStripe() { return stripe; } @@ -537,4 +558,20 @@ public class WhisperServerConfiguration extends Configuration { public VirtualThreadConfiguration getVirtualThreadConfiguration() { return virtualThreadConfiguration; } + + public MonitoredS3ObjectConfiguration getMaxmindCityDatabase() { + return maxmindCityDatabase; + } + + public MonitoredS3ObjectConfiguration getCallingTurnDnsRecords() { + return callingTurnDnsRecords; + } + + public MonitoredS3ObjectConfiguration getCallingTurnPerformanceTable() { + return callingTurnPerformanceTable; + } + + public MonitoredS3ObjectConfiguration getCallingTurnManualTable() { + return callingTurnManualTable; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 5b9922c9..48d3979d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -78,6 +78,10 @@ import org.whispersystems.textsecuregcm.backup.Cdn3BackupCredentialGenerator; import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager; import org.whispersystems.textsecuregcm.badges.ConfiguredProfileBadgeConverter; import org.whispersystems.textsecuregcm.badges.ResourceBundleLevelTranslator; +import org.whispersystems.textsecuregcm.calls.routing.CallDnsRecordsManager; +import org.whispersystems.textsecuregcm.calls.routing.CallRoutingTableManager; +import org.whispersystems.textsecuregcm.calls.routing.DynamicConfigTurnRouter; +import org.whispersystems.textsecuregcm.calls.routing.TurnCallRouter; import org.whispersystems.textsecuregcm.captcha.CaptchaChecker; import org.whispersystems.textsecuregcm.captcha.HCaptchaClient; import org.whispersystems.textsecuregcm.captcha.RecaptchaClient; @@ -93,6 +97,7 @@ import org.whispersystems.textsecuregcm.controllers.ArtController; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV4; +import org.whispersystems.textsecuregcm.controllers.CallRoutingController; import org.whispersystems.textsecuregcm.controllers.CallLinkController; import org.whispersystems.textsecuregcm.controllers.CertificateController; import org.whispersystems.textsecuregcm.controllers.ChallengeController; @@ -121,6 +126,7 @@ import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; +import org.whispersystems.textsecuregcm.geo.MaxMindDatabaseManager; import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor; import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService; @@ -452,6 +458,8 @@ public class WhisperServerService extends Application urls) { +public record TurnToken(String username, String password, List urls, List urlsWithIps, String hostname) { + public TurnToken(String username, String password, List urls) { + this(username, password, urls, null, null); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/TurnTokenGenerator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/TurnTokenGenerator.java index 6f2b378d..f2fb9308 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/TurnTokenGenerator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/TurnTokenGenerator.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.auth; +import org.whispersystems.textsecuregcm.calls.routing.TurnServerOptions; import org.whispersystems.textsecuregcm.configuration.TurnUriConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicTurnConfiguration; @@ -41,8 +42,15 @@ public class TurnTokenGenerator { } public TurnToken generate(final UUID aci) { + return generateToken(null, null, urls(aci)); + } + + public TurnToken generateWithTurnServerOptions(TurnServerOptions options) { + return generateToken(options.hostname(), options.urlsWithIps(), options.urlsWithHostname()); + } + + private TurnToken generateToken(String hostname, List urlsWithIps, List urlsWithHostname) { try { - final List urls = urls(aci); final Mac mac = Mac.getInstance(ALGORITHM); final long validUntilSeconds = Instant.now().plus(Duration.ofDays(1)).getEpochSecond(); final long user = Util.ensureNonNegativeInt(new SecureRandom().nextInt()); @@ -51,7 +59,7 @@ public class TurnTokenGenerator { mac.init(new SecretKeySpec(turnSecret, ALGORITHM)); final String password = Base64.getEncoder().encodeToString(mac.doFinal(userTime.getBytes())); - return new TurnToken(userTime, password, urls); + return new TurnToken(userTime, password, urlsWithHostname, urlsWithIps, hostname); } catch (final NoSuchAlgorithmException | InvalidKeyException e) { throw new AssertionError(e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecords.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecords.java new file mode 100644 index 00000000..0d20fea5 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecords.java @@ -0,0 +1,34 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import javax.validation.constraints.NotNull; +import java.net.InetAddress; +import java.util.List; +import java.util.Map; + +public record CallDnsRecords( + @NotNull + Map> aByRegion, + @NotNull + Map> aaaaByRegion +) { + public String getSummary() { + int numARecords = aByRegion.values().stream().mapToInt(List::size).sum(); + int numAAAARecords = aaaaByRegion.values().stream().mapToInt(List::size).sum(); + return String.format( + "(A records, %s regions, %s records), (AAAA records, %s regions, %s records)", + aByRegion.size(), + numARecords, + aaaaByRegion.size(), + numAAAARecords + ); + } + + public static CallDnsRecords empty() { + return new CallDnsRecords(Map.of(), Map.of()); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecordsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecordsManager.java new file mode 100644 index 00000000..6529ab27 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecordsManager.java @@ -0,0 +1,92 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import com.fasterxml.jackson.core.StreamReadFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import io.dropwizard.lifecycle.Managed; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.s3.S3ObjectMonitor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +public class CallDnsRecordsManager implements Supplier, Managed { + + private final S3ObjectMonitor objectMonitor; + + private final AtomicReference callDnsRecords = new AtomicReference<>(); + + private final Timer refreshTimer; + + private static final Logger log = LoggerFactory.getLogger(CallDnsRecordsManager.class); + + private static final ObjectMapper objectMapper = JsonMapper.builder() + .enable(StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION) + .build(); + + public CallDnsRecordsManager( + @Nonnull final ScheduledExecutorService executorService, + @Nonnull final MonitoredS3ObjectConfiguration configuration + ){ + this.objectMonitor = new S3ObjectMonitor( + configuration.s3Region(), + configuration.s3Bucket(), + configuration.objectKey(), + configuration.maxSize(), + executorService, + configuration.refreshInterval(), + this::handleDatabaseChanged + ); + + this.callDnsRecords.set(CallDnsRecords.empty()); + this.refreshTimer = Metrics.timer(MetricsUtil.name(CallDnsRecordsManager.class), "refresh"); + } + + private void handleDatabaseChanged(final InputStream inputStream) { + refreshTimer.record(() -> { + try (final InputStream bufferedInputStream = new BufferedInputStream(inputStream)) { + final CallDnsRecords newRecords = parseRecords(bufferedInputStream); + final CallDnsRecords oldRecords = callDnsRecords.getAndSet(newRecords); + log.info("Replaced dns records, old summary=[{}], new summary=[{}]", oldRecords != null ? oldRecords.getSummary() : "null", newRecords); + } catch (final IOException e) { + log.error("Failed to load Call DNS Records"); + } + }); + } + + static CallDnsRecords parseRecords(InputStream inputStream) throws IOException { + return objectMapper.readValue(inputStream, CallDnsRecords.class); + } + + @Override + public void start() throws Exception { + Managed.super.start(); + objectMonitor.start(); + } + + @Override + public void stop() throws Exception { + objectMonitor.stop(); + Managed.super.stop(); + callDnsRecords.getAndSet(null); + } + + @Override + public CallDnsRecords get() { + return this.callDnsRecords.get(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTable.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTable.java new file mode 100644 index 00000000..ea3b7b5a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTable.java @@ -0,0 +1,193 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; +import java.math.BigInteger; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.TreeMap; +import java.util.function.Function; +import java.util.stream.Stream; + +public class CallRoutingTable { + private final TreeMap>> ipv4Map; + private final TreeMap>> ipv6Map; + private final Map> geoToDatacenter; + + public CallRoutingTable( + Map> ipv4SubnetToDatacenter, + Map> ipv6SubnetToDatacenter, + Map> geoToDatacenter + ) { + this.ipv4Map = new TreeMap<>(); + for (Map.Entry> t : ipv4SubnetToDatacenter.entrySet()) { + if (!this.ipv4Map.containsKey(t.getKey().cidrBlockSize())) { + this.ipv4Map.put(t.getKey().cidrBlockSize(), new HashMap<>()); + } + this.ipv4Map + .get(t.getKey().cidrBlockSize()) + .put(t.getKey().subnet(), t.getValue()); + } + + this.ipv6Map = new TreeMap<>(); + for (Map.Entry> t : ipv6SubnetToDatacenter.entrySet()) { + if (!this.ipv6Map.containsKey(t.getKey().cidrBlockSize())) { + this.ipv6Map.put(t.getKey().cidrBlockSize(), new HashMap<>()); + } + this.ipv6Map + .get(t.getKey().cidrBlockSize()) + .put(t.getKey().subnet(), t.getValue()); + } + + this.geoToDatacenter = geoToDatacenter; + } + + public static CallRoutingTable empty() { + return new CallRoutingTable(Map.of(), Map.of(), Map.of()); + } + + public enum Protocol { + v4, + v6 + } + + public record GeoKey( + @NotBlank String continent, + @NotBlank String country, + @NotNull Optional subdivision, + @NotBlank Protocol protocol + ) {} + + /** + * Returns ordered list of fastest datacenters based on IP & Geo info. Prioritize the results based on subnet. + * Returns at most three, 2 by subnet and 1 by geo. Takes more from either bucket to hit 3. + */ + public List getDatacentersFor( + InetAddress address, + String continent, + String country, + Optional subdivision + ) { + final int NUM_DATACENTERS = 3; + + if(this.isEmpty()) { + return Collections.emptyList(); + } + + List dcsBySubnet = getDatacentersBySubnet(address); + List dcsByGeo = getDatacentersByGeo(continent, country, subdivision).stream() + .limit(NUM_DATACENTERS) + .filter(dc -> + (dcsBySubnet.isEmpty() || !dc.equals(dcsBySubnet.getFirst())) + && (dcsBySubnet.size() < 2 || !dc.equals(dcsBySubnet.get(1))) + ).toList(); + + return Stream.concat( + dcsBySubnet.stream().limit(dcsByGeo.isEmpty() ? NUM_DATACENTERS : NUM_DATACENTERS - 1), + dcsByGeo.stream()) + .limit(NUM_DATACENTERS) + .toList(); + } + + public boolean isEmpty() { + return this.ipv4Map.isEmpty() && this.ipv6Map.isEmpty() && this.geoToDatacenter.isEmpty(); + } + + /** + * Returns ordered list of fastest datacenters based on ip info. Prioritizes V4 connections. + */ + public List getDatacentersBySubnet(InetAddress address) throws IllegalArgumentException { + if(address instanceof Inet4Address) { + for(Map.Entry>> t: this.ipv4Map.descendingMap().entrySet()) { + int maskedIp = CidrBlock.IpV4CidrBlock.maskToSize((Inet4Address) address, t.getKey()); + if(t.getValue().containsKey(maskedIp)) { + return t.getValue().get(maskedIp); + } + } + } else if (address instanceof Inet6Address) { + for(Map.Entry>> t: this.ipv6Map.descendingMap().entrySet()) { + BigInteger maskedIp = CidrBlock.IpV6CidrBlock.maskToSize((Inet6Address) address, t.getKey()); + if(t.getValue().containsKey(maskedIp)) { + return t.getValue().get(maskedIp); + } + } + } else { + throw new IllegalArgumentException("Expected either an Inet4Address or Inet6Address"); + } + + return Collections.emptyList(); + } + + /** + * Returns ordered list of fastest datacenters based on geo info. Attempts to match based on subdivision, falls back + * to country based lookup. Does not attempt to look for nearby subdivisions. Prioritizes V4 connections. + */ + public List getDatacentersByGeo( + String continent, + String country, + Optional subdivision + ) { + GeoKey v4Key = new GeoKey(continent, country, subdivision, Protocol.v4); + List v4Options = this.geoToDatacenter.getOrDefault(v4Key, Collections.emptyList()); + List v4OptionsBackup = v4Options.isEmpty() && subdivision.isPresent() ? + this.geoToDatacenter.getOrDefault( + new GeoKey(continent, country, Optional.empty(), Protocol.v4), + Collections.emptyList()) + : Collections.emptyList(); + + GeoKey v6Key = new GeoKey(continent, country, subdivision, Protocol.v6); + List v6Options = this.geoToDatacenter.getOrDefault(v6Key, Collections.emptyList()); + List v6OptionsBackup = v6Options.isEmpty() && subdivision.isPresent() ? + this.geoToDatacenter.getOrDefault( + new GeoKey(continent, country, Optional.empty(), Protocol.v6), + Collections.emptyList()) + : Collections.emptyList(); + + return Stream.of( + v4Options.stream(), + v6Options.stream(), + v4OptionsBackup.stream(), + v6OptionsBackup.stream() + ) + .flatMap(Function.identity()) + .distinct() + .toList(); + } + + public String toSummaryString() { + return String.format( + "[Ipv4Table=%s rows, Ipv6Table=%s rows, GeoTable=%s rows]", + ipv4Map.size(), + ipv6Map.size(), + geoToDatacenter.size() + ); + } + + @Override + public boolean equals(final Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + CallRoutingTable that = (CallRoutingTable) o; + return Objects.equals(ipv4Map, that.ipv4Map) && Objects.equals(ipv6Map, that.ipv6Map) && Objects.equals( + geoToDatacenter, that.geoToDatacenter); + } + + @Override + public int hashCode() { + return Objects.hash(ipv4Map, ipv6Map, geoToDatacenter); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableManager.java new file mode 100644 index 00000000..ab28e85f --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableManager.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import io.dropwizard.lifecycle.Managed; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.s3.S3ObjectMonitor; +import javax.annotation.Nonnull; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +public class CallRoutingTableManager implements Supplier, Managed { + + private final S3ObjectMonitor objectMonitor; + + private final AtomicReference routingTable = new AtomicReference<>(); + + private final String tableTag; + + private final Timer refreshTimer; + + private static final Logger log = LoggerFactory.getLogger(CallRoutingTableManager.class); + + public CallRoutingTableManager( + @Nonnull final ScheduledExecutorService executorService, + @Nonnull final MonitoredS3ObjectConfiguration configuration, + @Nonnull final String tableTag + ){ + this.objectMonitor = new S3ObjectMonitor( + configuration.s3Region(), + configuration.s3Bucket(), + configuration.objectKey(), + configuration.maxSize(), + executorService, + configuration.refreshInterval(), + this::handleDatabaseChanged + ); + + this.tableTag = tableTag; + this.routingTable.set(CallRoutingTable.empty()); + this.refreshTimer = Metrics.timer(MetricsUtil.name(CallRoutingTableManager.class), tableTag); + } + + private void handleDatabaseChanged(final InputStream inputStream) { + refreshTimer.record(() -> { + try(InputStreamReader reader = new InputStreamReader(inputStream)) { + CallRoutingTable newTable = CallRoutingTableParser.fromJson(reader); + this.routingTable.set(newTable); + log.info("Replaced {} call routing table: {}", tableTag, newTable.toSummaryString()); + } catch (final IOException e) { + log.error("Failed to parse and update {} call routing table", tableTag); + } + }); + } + + @Override + public void start() throws Exception { + Managed.super.start(); + objectMonitor.start(); + } + + @Override + public void stop() throws Exception { + Managed.super.stop(); + objectMonitor.stop(); + routingTable.getAndSet(null); + } + + @Override + public CallRoutingTable get() { + return this.routingTable.get(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableParser.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableParser.java new file mode 100644 index 00000000..a3538b6d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableParser.java @@ -0,0 +1,185 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import com.fasterxml.jackson.core.StreamReadFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.Reader; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +final class CallRoutingTableParser { + + private final static int IPV4_DEFAULT_BLOCK_SIZE = 24; + private final static int IPV6_DEFAULT_BLOCK_SIZE = 48; + private static final ObjectMapper objectMapper = JsonMapper.builder() + .enable(StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION) + .build(); + + /** Used for parsing JSON */ + private static class RawCallRoutingTable { + public Map> ipv4GeoToDataCenters = Map.of(); + public Map> ipv6GeoToDataCenters = Map.of(); + public Map> ipv4SubnetsToDatacenters = Map.of(); + public Map> ipv6SubnetsToDatacenters = Map.of(); + } + + private final static String WHITESPACE_REGEX = "\\s+"; + + public static CallRoutingTable fromJson(final Reader inputReader) throws IOException { + try (final BufferedReader reader = new BufferedReader(inputReader)) { + RawCallRoutingTable rawTable = objectMapper.readValue(reader, RawCallRoutingTable.class); + + Map> ipv4SubnetToDatacenter = rawTable.ipv4SubnetsToDatacenters + .entrySet() + .stream() + .collect(Collectors.toUnmodifiableMap( + e -> (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock(e.getKey(), IPV4_DEFAULT_BLOCK_SIZE), + Map.Entry::getValue + )); + + Map> ipv6SubnetToDatacenter = rawTable.ipv6SubnetsToDatacenters + .entrySet() + .stream() + .collect(Collectors.toUnmodifiableMap( + e -> (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock(e.getKey(), IPV6_DEFAULT_BLOCK_SIZE), + Map.Entry::getValue + )); + + Map> geoToDatacenter = Stream.concat( + rawTable.ipv4GeoToDataCenters + .entrySet() + .stream() + .map(e -> Map.entry(parseRawGeoKey(e.getKey(), CallRoutingTable.Protocol.v4), e.getValue())), + rawTable.ipv6GeoToDataCenters + .entrySet() + .stream() + .map(e -> Map.entry(parseRawGeoKey(e.getKey(), CallRoutingTable.Protocol.v6), e.getValue())) + ).collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + + return new CallRoutingTable( + ipv4SubnetToDatacenter, + ipv6SubnetToDatacenter, + geoToDatacenter + ); + } + } + + private static CallRoutingTable.GeoKey parseRawGeoKey(String rawKey, CallRoutingTable.Protocol protocol) { + String[] splits = rawKey.split("-"); + if (splits.length < 2 || splits.length > 3) { + throw new IllegalArgumentException("Invalid raw key"); + } + + Optional subdivision = splits.length < 3 ? Optional.empty() : Optional.of(splits[2]); + return new CallRoutingTable.GeoKey(splits[0], splits[1], subdivision, protocol); + } + + /** + * Parses a call routing table in TSV format. Example below - see tests for more examples: + 192.0.2.0/24 northamerica-northeast1 + 198.51.100.0/24 us-south1 + 203.0.113.0/24 asia-southeast1 + + 2001:db8:b0a9::/48 us-east4 + 2001:db8:b0f5::/48 us-central1 northamerica-northeast1 us-east4 + 2001:db8:9406::/48 us-east1 us-central1 + + SA-SR-v4 us-east1 us-east4 + SA-SR-v6 us-east1 us-south1 + SA-UY-v4 southamerica-west1 southamerica-east1 europe-west3 + SA-UY-v6 southamerica-west1 europe-west4 + SA-VE-v4 us-east1 us-east4 us-south1 + SA-VE-v6 us-east1 northamerica-northeast1 us-east4 + ZZ-ZZ-v4 asia-south1 europe-southwest1 australia-southeast1 + */ + public static CallRoutingTable fromTsv(final Reader inputReader) throws IOException { + try (final BufferedReader reader = new BufferedReader(inputReader)) { + // use maps to silently dedupe CidrBlocks + Map> ipv4Map = new HashMap<>(); + Map> ipv6Map = new HashMap<>(); + Map> ipGeoTable = new HashMap<>(); + String line; + while((line = reader.readLine()) != null) { + if(line.isBlank()) { + continue; + } + + List splits = Arrays.stream(line.split(WHITESPACE_REGEX)).filter(s -> !s.isBlank()).toList(); + if (splits.size() < 2) { + throw new IllegalStateException("Invalid row, expected some key and list of values"); + } + + List datacenters = splits.subList(1, splits.size()); + switch (guessLineType(splits)) { + case v4 -> { + CidrBlock cidrBlock = CidrBlock.parseCidrBlock(splits.getFirst()); + if(!(cidrBlock instanceof CidrBlock.IpV4CidrBlock)) { + throw new IllegalArgumentException("Expected an ipv4 cidr block"); + } + ipv4Map.put((CidrBlock.IpV4CidrBlock) cidrBlock, datacenters); + } + case v6 -> { + CidrBlock cidrBlock = CidrBlock.parseCidrBlock(splits.getFirst()); + if(!(cidrBlock instanceof CidrBlock.IpV6CidrBlock)) { + throw new IllegalArgumentException("Expected an ipv6 cidr block"); + } + ipv6Map.put((CidrBlock.IpV6CidrBlock) cidrBlock, datacenters); + } + case Geo -> { + String[] geo = splits.getFirst().split("-"); + if(geo.length < 3) { + throw new IllegalStateException("Geo row key invalid, expected atleast continent, country, and protocol"); + } + String continent = geo[0]; + String country = geo[1]; + Optional subdivision = geo.length > 3 ? Optional.of(geo[2]) : Optional.empty(); + CallRoutingTable.Protocol protocol = CallRoutingTable.Protocol.valueOf(geo[geo.length - 1].toLowerCase()); + CallRoutingTable.GeoKey tableKey = new CallRoutingTable.GeoKey( + continent, + country, + subdivision, + protocol + ); + ipGeoTable.put(tableKey, datacenters); + } + } + } + + return new CallRoutingTable( + ipv4Map, + ipv6Map, + ipGeoTable + ); + } + } + + private static LineType guessLineType(List splits) { + String first = splits.getFirst(); + if (first.contains("-")) { + return LineType.Geo; + } else if(first.contains(":")) { + return LineType.v6; + } else if (first.contains(".")) { + return LineType.v4; + } + + throw new IllegalArgumentException(String.format("Invalid line, could not determine type from '%s'", first)); + } + + private enum LineType { + v4, v6, Geo + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CidrBlock.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CidrBlock.java new file mode 100644 index 00000000..e3af1b2a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/CidrBlock.java @@ -0,0 +1,137 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import java.math.BigInteger; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.UnknownHostException; + +/** + * Can be used to check if an IP is in the CIDR block + */ +public interface CidrBlock { + + boolean ipInBlock(InetAddress address); + + static CidrBlock parseCidrBlock(String cidrBlock, int defaultBlockSize) { + String[] splits = cidrBlock.split("/"); + if(splits.length > 2) { + throw new IllegalArgumentException("Invalid cidr block format, expected {address}/{blocksize}"); + } + + try { + int blockSize = splits.length == 2 ? Integer.parseInt(splits[1]) : defaultBlockSize; + return parseCidrBlockInner(splits[0], blockSize); + } catch (NumberFormatException e) { + throw new IllegalArgumentException(String.format("Invalid block size specified: '%s'", splits[1])); + } + } + + static CidrBlock parseCidrBlock(String cidrBlock) { + String[] splits = cidrBlock.split("/"); + if (splits.length != 2) { + throw new IllegalArgumentException("Invalid cidr block format, expected {address}/{blocksize}"); + } + + try { + int blockSize = Integer.parseInt(splits[1]); + return parseCidrBlockInner(splits[0], blockSize); + } catch (NumberFormatException e) { + throw new IllegalArgumentException(String.format("Invalid block size specified: '%s'", splits[1])); + } + } + + private static CidrBlock parseCidrBlockInner(String rawAddress, int blockSize) { + try { + InetAddress address = InetAddress.getByName(rawAddress); + if(address instanceof Inet4Address) { + return IpV4CidrBlock.of((Inet4Address) address, blockSize); + } else if (address instanceof Inet6Address) { + return IpV6CidrBlock.of((Inet6Address) address, blockSize); + } else { + throw new IllegalArgumentException("Must be an ipv4 or ipv6 string"); + } + } catch (UnknownHostException e) { + throw new IllegalArgumentException(e); + } + } + + record IpV4CidrBlock(int subnet, int subnetMask, int cidrBlockSize) implements CidrBlock { + public static IpV4CidrBlock of(Inet4Address subnet, int cidrBlockSize) { + if(cidrBlockSize > 32 || cidrBlockSize < 0) { + throw new IllegalArgumentException("Invalid cidrBlockSize"); + } + + int subnetMask = mask(cidrBlockSize); + int maskedIp = ipToInt(subnet) & subnetMask; + return new IpV4CidrBlock(maskedIp, subnetMask, cidrBlockSize); + } + + public boolean ipInBlock(InetAddress address) { + if(!(address instanceof Inet4Address)) { + return false; + } + int ip = ipToInt((Inet4Address) address); + return (ip & subnetMask) == subnet; + } + + private static int ipToInt(Inet4Address address) { + byte[] octets = address.getAddress(); + return (octets[0] & 0xff) << 24 | + (octets[1] & 0xff) << 16 | + (octets[2] & 0xff) << 8 | + octets[3] & 0xff; + } + + private static int mask(int cidrBlockSize) { + return (int) (-1L << (32 - cidrBlockSize)); + } + + public static int maskToSize(Inet4Address address, int cidrBlockSize) { + return ipToInt(address) & mask(cidrBlockSize); + } + } + + record IpV6CidrBlock(BigInteger subnet, BigInteger subnetMask, int cidrBlockSize) implements CidrBlock { + + private static final BigInteger MINUS_ONE = BigInteger.valueOf(-1); + + public static IpV6CidrBlock of(Inet6Address subnet, int cidrBlockSize) { + if(cidrBlockSize > 128 || cidrBlockSize < 0) { + throw new IllegalArgumentException("Invalid cidrBlockSize"); + } + + BigInteger subnetMask = mask(cidrBlockSize); + BigInteger maskedIp = ipToInt(subnet).and(subnetMask); + return new IpV6CidrBlock(maskedIp, subnetMask, cidrBlockSize); + } + + public boolean ipInBlock(InetAddress address) { + if(!(address instanceof Inet6Address)) { + return false; + } + BigInteger ip = ipToInt((Inet6Address) address); + return ip.and(subnetMask).equals(subnet); + } + + private static BigInteger ipToInt(Inet6Address ipAddress) { + byte[] octets = ipAddress.getAddress(); + assert octets.length == 16; + + return new BigInteger(octets); + } + + private static BigInteger mask(int cidrBlockSize) { + return MINUS_ONE.shiftLeft(128 - cidrBlockSize); + } + + public static BigInteger maskToSize(Inet6Address address, int cidrBlockSize) { + return ipToInt(address).and(mask(cidrBlockSize)); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/DynamicConfigTurnRouter.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/DynamicConfigTurnRouter.java new file mode 100644 index 00000000..b8223067 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/DynamicConfigTurnRouter.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import org.whispersystems.textsecuregcm.configuration.TurnUriConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicTurnConfiguration; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.WeightedRandomSelect; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.UUID; + +/** Uses DynamicConfig to help route a turn request */ +public class DynamicConfigTurnRouter { + + private static final Random rng = new Random(); + + public static final long RANDOMIZE_RATE_BASIS = 100_000; + + private final DynamicConfigurationManager dynamicConfigurationManager; + + public DynamicConfigTurnRouter(final DynamicConfigurationManager dynamicConfigurationManager) { + this.dynamicConfigurationManager = dynamicConfigurationManager; + } + + public List targetedUrls(final UUID aci) { + final DynamicTurnConfiguration turnConfig = dynamicConfigurationManager.getConfiguration().getTurnConfiguration(); + + final Optional enrolled = turnConfig.getUriConfigs().stream() + .filter(config -> config.getEnrolledAcis().contains(aci)) + .findFirst(); + + return enrolled + .map(turnUriConfiguration -> turnUriConfiguration.getUris().stream().toList()) + .orElse(Collections.emptyList()); + } + + public List randomUrls() { + final DynamicTurnConfiguration turnConfig = dynamicConfigurationManager.getConfiguration().getTurnConfiguration(); + + // select from turn server sets by weighted choice + return WeightedRandomSelect.select(turnConfig + .getUriConfigs() + .stream() + .map(c -> new Pair<>(c.getUris(), c.getWeight())).toList()); + } + + public String getHostname() { + final DynamicTurnConfiguration turnConfig = dynamicConfigurationManager.getConfiguration().getTurnConfiguration(); + return turnConfig.getHostname(); + } + + public long getRandomizeRate() { + final DynamicTurnConfiguration turnConfig = dynamicConfigurationManager.getConfiguration().getTurnConfiguration(); + return turnConfig.getRandomizeRate(); + } + + public boolean shouldRandomize() { + long rate = getRandomizeRate(); + return rate >= RANDOMIZE_RATE_BASIS || rng.nextLong(0, DynamicConfigTurnRouter.RANDOMIZE_RATE_BASIS) < rate; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/TurnCallRouter.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/TurnCallRouter.java new file mode 100644 index 00000000..94055d84 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/TurnCallRouter.java @@ -0,0 +1,149 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import com.maxmind.geoip2.DatabaseReader; +import com.maxmind.geoip2.exception.GeoIp2Exception; +import com.maxmind.geoip2.model.CityResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.util.Util; +import javax.annotation.Nonnull; +import java.io.IOException; +import java.net.InetAddress; +import java.util.*; +import java.util.function.Supplier; +import java.util.stream.Stream; + +/** + * Returns routes based on performance tables, manually routing tables, and target routing. Falls back to a random Turn + * instance that the server knows about. + */ +public class TurnCallRouter { + + private final Logger logger = LoggerFactory.getLogger(TurnCallRouter.class); + + private final Supplier callDnsRecords; + private final Supplier performanceRouting; + private final Supplier manualRouting; + private final DynamicConfigTurnRouter configTurnRouter; + private final Supplier geoIp; + + public TurnCallRouter( + @Nonnull Supplier callDnsRecords, + @Nonnull Supplier performanceRouting, + @Nonnull Supplier manualRouting, + @Nonnull DynamicConfigTurnRouter configTurnRouter, + @Nonnull Supplier geoIp + ) { + this.performanceRouting = performanceRouting; + this.callDnsRecords = callDnsRecords; + this.manualRouting = manualRouting; + this.configTurnRouter = configTurnRouter; + this.geoIp = geoIp; + } + + /** + * Gets Turn Instance addresses. Returns both the IPv4 and IPv6 addresses. Prioritizes V4 connections. + * @param aci aci of client + * @param clientAddress IP address to base routing on + * @param instanceLimit max instances to return options for + */ + public TurnServerOptions getRoutingFor( + @Nonnull final UUID aci, + @Nonnull final Optional clientAddress, + final int instanceLimit + ) { + try { + return getRoutingForInner(aci, clientAddress, instanceLimit); + } catch(Exception e) { + logger.error("Failed to perform routing", e); + return new TurnServerOptions(this.configTurnRouter.getHostname(), null, this.configTurnRouter.randomUrls()); + } + } + + TurnServerOptions getRoutingForInner( + @Nonnull final UUID aci, + @Nonnull final Optional clientAddress, + final int instanceLimit + ) { + if (instanceLimit < 1) { + throw new IllegalArgumentException("Limit cannot be less than one"); + } + + String hostname = this.configTurnRouter.getHostname(); + + List targetedUrls = this.configTurnRouter.targetedUrls(aci); + if(!targetedUrls.isEmpty()) { + return new TurnServerOptions(hostname, null, targetedUrls); + } + + if(clientAddress.isEmpty() || this.configTurnRouter.shouldRandomize()) { + return new TurnServerOptions(hostname, null, this.configTurnRouter.randomUrls()); + } + + CityResponse geoInfo; + try { + geoInfo = geoIp.get().city(clientAddress.get()); + } catch (IOException | GeoIp2Exception e) { + throw new RuntimeException(e); + } + Optional subdivision = !geoInfo.getSubdivisions().isEmpty() + ? Optional.of(geoInfo.getSubdivisions().getFirst().getIsoCode()) + : Optional.empty(); + + List datacenters = this.manualRouting.get().getDatacentersFor( + clientAddress.get(), + geoInfo.getContinent().getCode(), + geoInfo.getCountry().getIsoCode(), + subdivision + ); + + if (datacenters.isEmpty()){ + datacenters = this.performanceRouting.get().getDatacentersFor( + clientAddress.get(), + geoInfo.getContinent().getCode(), + geoInfo.getCountry().getIsoCode(), + subdivision + ); + } + List urlsWithIps = getUrlsForInstances(selectInstances(datacenters, instanceLimit)); + return new TurnServerOptions(hostname, urlsWithIps, this.configTurnRouter.randomUrls()); + } + + private List selectInstances(List datacenters, int limit) { + if(datacenters.isEmpty() || limit == 0) { + return Collections.emptyList(); + } + + CallDnsRecords dnsRecords = this.callDnsRecords.get(); + List ipv4Selection = datacenters.stream() + .flatMap(dc -> Util.randomNOfStable(dnsRecords.aByRegion().get(dc), 2).stream()) + .toList(); + List ipv6Selection = datacenters.stream() + .flatMap(dc -> Util.randomNOfStable(dnsRecords.aaaaByRegion().get(dc), 2).stream()) + .toList(); + if (ipv4Selection.size() < ipv6Selection.size()) { + ipv4Selection = ipv4Selection.stream().limit(limit / 2).toList(); + ipv6Selection = ipv6Selection.stream().limit(limit - ipv4Selection.size()).toList(); + } else { + ipv6Selection = ipv6Selection.stream().limit(limit / 2).toList(); + ipv4Selection = ipv4Selection.stream().limit(limit - ipv6Selection.size()).toList(); + } + + return Stream.concat(ipv4Selection.stream(), ipv6Selection.stream()).map(InetAddress::getHostAddress).toList(); + } + + private static List getUrlsForInstances(List instanceIps) { + return instanceIps.stream().flatMap(ip -> Stream.of( + String.format("stun:%s", ip), + String.format("turn:%s", ip), + String.format("turn:%s:80?transport=tcp", ip), + String.format("turns:%s:443?transport=tcp", ip) + ) + ).toList(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/TurnServerOptions.java b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/TurnServerOptions.java new file mode 100644 index 00000000..a1381ea2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/calls/routing/TurnServerOptions.java @@ -0,0 +1,11 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import java.util.List; + +public record TurnServerOptions(String hostname, List urlsWithIps, List urlsWithHostname) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicTurnConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicTurnConfiguration.java index e34a2143..c811377a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicTurnConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicTurnConfiguration.java @@ -13,10 +13,28 @@ import org.whispersystems.textsecuregcm.configuration.TurnUriConfiguration; public class DynamicTurnConfiguration { + @JsonProperty + private String hostname; + + /** + * Rate at which to prioritize a random turn URL to exercise all endpoints. + * Based on a 100,000 basis, where 100,000 == 100%. + */ + @JsonProperty + private long randomizeRate = 5_000; + @JsonProperty private List<@Valid TurnUriConfiguration> uriConfigs = Collections.emptyList(); public List getUriConfigs() { return uriConfigs; } + + public long getRandomizeRate() { + return randomizeRate; + } + + public String getHostname() { + return hostname; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 460324e2..588d393e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -93,6 +93,7 @@ public class AccountController { this.usernameHashZkProofVerifier = usernameHashZkProofVerifier; } + @Deprecated @GET @Path("/turn/") @Produces(MediaType.APPLICATION_JSON) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingController.java new file mode 100644 index 00000000..b365fb41 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingController.java @@ -0,0 +1,85 @@ +package org.whispersystems.textsecuregcm.controllers; + +import io.dropwizard.auth.Auth; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.responses.ApiResponse; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Optional; +import java.util.UUID; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.TurnToken; +import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; +import org.whispersystems.textsecuregcm.calls.routing.TurnServerOptions; +import org.whispersystems.textsecuregcm.calls.routing.TurnCallRouter; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +@Path("/v1/calling") +@io.swagger.v3.oas.annotations.tags.Tag(name = "Calling") +public class CallRoutingController { + + private static final int TURN_INSTANCE_LIMIT = 6; + private static final Counter INVALID_IP_COUNTER = Metrics.counter(name(CallRoutingController.class, "invalidIP")); + private static final Logger log = LoggerFactory.getLogger(CallRoutingController.class); + private final RateLimiters rateLimiters; + private final TurnCallRouter turnCallRouter; + private final TurnTokenGenerator tokenGenerator; + + public CallRoutingController( + final RateLimiters rateLimiters, + final TurnCallRouter turnCallRouter, + final TurnTokenGenerator tokenGenerator + ) { + this.rateLimiters = rateLimiters; + this.turnCallRouter = turnCallRouter; + this.tokenGenerator = tokenGenerator; + } + + @GET + @Path("/relays") + @Produces(MediaType.APPLICATION_JSON) + @Operation( + summary = "Get 1:1 calling relay options for the client", + description = """ + Get 1:1 relay addresses in IpV4, Ipv6, and URL formats. + """ + ) + @ApiResponse(responseCode = "200", description = "`JSON` with call endpoints.", useReturnTypeSchema = true) + @ApiResponse(responseCode = "400", description = "Invalid get call endpoint request.") + @ApiResponse(responseCode = "401", description = "Account authentication check failed.") + @ApiResponse(responseCode = "422", description = "Invalid request format.") + @ApiResponse(responseCode = "429", description = "Ratelimited.") + public TurnToken getCallingRelays( + final @Auth AuthenticatedAccount auth, + @Context ContainerRequestContext requestContext + ) throws RateLimitExceededException { + UUID aci = auth.getAccount().getUuid(); + rateLimiters.getCallEndpointLimiter().validate(aci); + + Optional address = Optional.empty(); + try { + final String remoteAddress = (String) requestContext.getProperty( + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); + address = Optional.of(InetAddress.getByName(remoteAddress)); + } catch (UnknownHostException e) { + INVALID_IP_COUNTER.increment(); + } + + TurnServerOptions options = turnCallRouter.getRoutingFor(aci, address, TURN_INSTANCE_LIMIT); + return tokenGenerator.generateWithTurnServerOptions(options); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/geo/MaxMindDatabaseManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/geo/MaxMindDatabaseManager.java new file mode 100644 index 00000000..3f4cdd3f --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/geo/MaxMindDatabaseManager.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.geo; + +import com.maxmind.db.CHMCache; +import com.maxmind.geoip2.DatabaseReader; +import com.maxmind.geoip2.GeoIp2Provider; +import io.dropwizard.lifecycle.Managed; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import org.apache.commons.compress.archivers.ArchiveEntry; +import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; +import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; +import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.s3.S3ObjectMonitor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +public class MaxMindDatabaseManager implements Supplier, Managed { + + private final S3ObjectMonitor databaseMonitor; + + private final AtomicReference databaseReader = new AtomicReference<>(); + + private final String databaseTag; + + private final Timer refreshTimer; + + private static final Logger log = LoggerFactory.getLogger(MaxMindDatabaseManager.class); + + public MaxMindDatabaseManager( + @Nonnull final ScheduledExecutorService executorService, + @Nonnull final MonitoredS3ObjectConfiguration configuration, + @Nonnull final String databaseTag + ){ + this.databaseMonitor = new S3ObjectMonitor( + configuration.s3Region(), + configuration.s3Bucket(), + configuration.objectKey(), + configuration.maxSize(), + executorService, + configuration.refreshInterval(), + this::handleDatabaseChanged + ); + + this.databaseTag = databaseTag; + this.refreshTimer = Metrics.timer(MetricsUtil.name(MaxMindDatabaseManager.class), "refresh", databaseTag); + } + + private void handleDatabaseChanged(final InputStream inputStream) { + refreshTimer.record(() -> { + boolean foundDatabaseEntry = false; + + try (final InputStream bufferedInputStream = new BufferedInputStream(inputStream); + final GzipCompressorInputStream gzipInputStream = new GzipCompressorInputStream(bufferedInputStream); + final TarArchiveInputStream tarInputStream = new TarArchiveInputStream(gzipInputStream)) { + + ArchiveEntry nextEntry; + + while ((nextEntry = tarInputStream.getNextEntry()) != null) { + if (nextEntry.getName().toLowerCase().endsWith(".mmdb")) { + foundDatabaseEntry = true; + + final DatabaseReader oldReader = databaseReader.getAndSet( + new DatabaseReader.Builder(tarInputStream).withCache(new CHMCache()).build() + ); + if (oldReader != null) { + oldReader.close(); + } + break; + } + } + } catch (final IOException e) { + log.error(String.format("Failed to load MaxMind database, tag %s", databaseTag)); + } + + if (!foundDatabaseEntry) { + log.warn(String.format("No .mmdb entry loaded from input stream, tag %s", databaseTag)); + } + }); + } + + @Override + public void start() throws Exception { + Managed.super.start(); + databaseMonitor.start(); + } + + @Override + public void stop() throws Exception { + Managed.super.stop(); + databaseMonitor.stop(); + + final DatabaseReader reader = databaseReader.getAndSet(null); + if(reader != null) { + reader.close(); + } + } + + @Override + public DatabaseReader get() { + return this.databaseReader.get(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 6e50a1b4..18d0037d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -49,6 +49,7 @@ public class RateLimiters extends BaseRateLimiters { SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(2, Duration.ofDays(7))), PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))), PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))), + GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10))), CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15))), INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000))), EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))), @@ -220,6 +221,10 @@ public class RateLimiters extends BaseRateLimiters { return forDescriptor(For.CREATE_CALL_LINK); } + public RateLimiter getCallEndpointLimiter() { + return forDescriptor(For.GET_CALLING_RELAYS); + } + public RateLimiter getInboundMessageBytes() { return forDescriptor(For.INBOUND_MESSAGE_BYTES); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/Util.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/Util.java index 19ebbc3d..72d01d11 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/Util.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/Util.java @@ -10,13 +10,19 @@ import com.google.i18n.phonenumbers.PhoneNumberUtil.PhoneNumberFormat; import com.google.i18n.phonenumbers.Phonenumber.PhoneNumber; import java.time.Clock; import java.time.Duration; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Locale.LanguageRange; import java.util.Optional; +import java.util.Random; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Function; +import java.util.random.RandomGenerator; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -26,6 +32,8 @@ import org.apache.commons.lang3.StringUtils; public class Util { + private static final RandomGenerator rng = new Random(); + private static final Pattern COUNTRY_CODE_PATTERN = Pattern.compile("^\\+([17]|2[07]|3[0123469]|4[013456789]|5[12345678]|6[0123456]|8[1246]|9[0123458]|\\d{3})"); private static final PhoneNumberUtil PHONE_NUMBER_UTIL = PhoneNumberUtil.getInstance(); @@ -160,4 +168,47 @@ public class Util { return n == Long.MIN_VALUE ? 0 : Math.abs(n); } + /** + * Chooses min(values.size(), n) random values. + *
+ * Copies the input Array - use for small lists only or for when n/values.size() is near 1. + */ + public static List randomNOf(List values, int n) { + if(values == null || values.isEmpty()) { + return Collections.emptyList(); + } + + List result = new ArrayList<>(values); + if(n >= values.size()) { + return result; + } + + Collections.shuffle(result); + return result.stream().limit(n).toList(); + } + + /** + * Chooses min(values.size(), n) random values. Return value is in stable order from input values. + * Not uniform random, but good enough. + *
+ * Does NOT copy the input Array. + */ + public static List randomNOfStable(List values, int n) { + if(values == null || values.isEmpty()) { + return Collections.emptyList(); + } + if(n >= values.size()) { + return values; + } + + Set indices = new HashSet<>(rng.ints(0, values.size()).distinct().limit(n).boxed().toList()); + List result = new ArrayList(n); + for(int i = 0; i < values.size() && result.size() < n; i++) { + if(indices.contains(i)) { + result.add(values.get(i)); + } + } + + return result; + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecordsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecordsManagerTest.java new file mode 100644 index 00000000..24fb6075 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallDnsRecordsManagerTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import org.junit.jupiter.api.Test; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.stream.Stream; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class CallDnsRecordsManagerTest { + + @Test + public void testParseDnsRecords() throws IOException { + var input = """ + { + "aByRegion": { + "datacenter-1": [ + "127.0.0.1" + ], + "datacenter-2": [ + "127.0.0.2", + "127.0.0.3" + ], + "datacenter-3": [ + "127.0.0.4", + "127.0.0.5" + ], + "datacenter-4": [ + "127.0.0.6", + "127.0.0.7" + ] + }, + "aaaaByRegion": { + "datacenter-1": [ + "2600:1111:2222:3333:0:20:0:0", + "2600:1111:2222:3333:0:21:0:0", + "2600:1111:2222:3333:0:22:0:0" + ], + "datacenter-2": [ + "2600:1111:2222:3333:0:23:0:0", + "2600:1111:2222:3333:0:24:0:0" + ], + "datacenter-3": [ + "2600:1111:2222:3333:0:25:0:0", + "2600:1111:2222:3333:0:26:0:0" + ], + "datacenter-4": [ + "2600:1111:2222:3333:0:27:0:0" + ] + } + } + """; + + var actual = CallDnsRecordsManager.parseRecords(new ByteArrayInputStream(input.getBytes(StandardCharsets.UTF_8))); + var expected = new CallDnsRecords( + Map.of( + "datacenter-1", Stream.of("127.0.0.1").map(this::getAddressByName).toList(), + "datacenter-2", Stream.of("127.0.0.2", "127.0.0.3").map(this::getAddressByName).toList(), + "datacenter-3", Stream.of("127.0.0.4", "127.0.0.5").map(this::getAddressByName).toList(), + "datacenter-4", Stream.of("127.0.0.6", "127.0.0.7").map(this::getAddressByName).toList() + ), + Map.of( + "datacenter-1", Stream.of( + "2600:1111:2222:3333:0:20:0:0", + "2600:1111:2222:3333:0:21:0:0", + "2600:1111:2222:3333:0:22:0:0" + ).map(this::getAddressByName).toList(), + "datacenter-2", Stream.of( + "2600:1111:2222:3333:0:23:0:0", + "2600:1111:2222:3333:0:24:0:0") + .map(this::getAddressByName).toList(), + "datacenter-3", Stream.of( + "2600:1111:2222:3333:0:25:0:0", + "2600:1111:2222:3333:0:26:0:0") + .map(this::getAddressByName).toList(), + "datacenter-4", Stream.of( + "2600:1111:2222:3333:0:27:0:0" + ).map(this::getAddressByName).toList() + ) + ); + + assertThat(actual).isEqualTo(expected); + } + + InetAddress getAddressByName(String ip) { + try { + return InetAddress.getByName(ip) ; + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableParserTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableParserTest.java new file mode 100644 index 00000000..a0427a59 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableParserTest.java @@ -0,0 +1,245 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.io.StringReader; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class CallRoutingTableParserTest { + + @Test + public void testParserSuccess() throws IOException { + var input = + """ + 192.1.12.0/24 datacenter-1 datacenter-2 datacenter-3 + 193.123.123.0/24 datacenter-1 datacenter-2 + 1.123.123.0/24 datacenter-1 + + 2001:db8:b0aa::/48 datacenter-1 + 2001:db8:b0ab::/48 datacenter-3 datacenter-1 datacenter-2 + 2001:db8:b0ac::/48 datacenter-2 datacenter-1 + + SA-SR-v4 datacenter-3 + SA-UY-v4 datacenter-3 datacenter-1 datacenter-2 + NA-US-VA-v6 datacenter-2 datacenter-1 + """; + var actual = CallRoutingTableParser.fromTsv(new StringReader(input)); + var expected = new CallRoutingTable( + Map.of( + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("192.1.12.0/24"), List.of("datacenter-1", "datacenter-2", "datacenter-3"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("193.123.123.0/24"), List.of("datacenter-1", "datacenter-2"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("1.123.123.0/24"), List.of("datacenter-1") + ), + Map.of( + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0aa::/48"), List.of("datacenter-1"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ab::/48"), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ac::/48"), List.of("datacenter-2", "datacenter-1") + ), + Map.of( + new CallRoutingTable.GeoKey("SA", "SR", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3"), + new CallRoutingTable.GeoKey("SA", "UY", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + new CallRoutingTable.GeoKey("NA", "US", Optional.of("VA"), CallRoutingTable.Protocol.v6), List.of("datacenter-2", "datacenter-1") + ) + ); + + assertThat(actual).isEqualTo(expected); + } + + @Test + public void testParserVariousWhitespaceSuccess() throws IOException { + var input = + """ + + 192.1.12.0/24\t \tdatacenter-1\t\t datacenter-2 datacenter-3 + \t193.123.123.0/24\tdatacenter-1\tdatacenter-2 + + + 1.123.123.0/24\t datacenter-1 + 2001:db8:b0aa::/48\t \tdatacenter-1 + 2001:db8:b0ab::/48 \tdatacenter-3\tdatacenter-1 datacenter-2 + 2001:db8:b0ac::/48\tdatacenter-2\tdatacenter-1 + + + + + + + SA-SR-v4 datacenter-3 + + + + + SA-UY-v4\tdatacenter-3\tdatacenter-1\tdatacenter-2 + NA-US-VA-v6 datacenter-2 \tdatacenter-1 + """; + var actual = CallRoutingTableParser.fromTsv(new StringReader(input)); + var expected = new CallRoutingTable( + Map.of( + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("192.1.12.0/24"), List.of("datacenter-1", "datacenter-2", "datacenter-3"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("193.123.123.0/24"), List.of("datacenter-1", "datacenter-2"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("1.123.123.0/24"), List.of("datacenter-1") + ), + Map.of( + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0aa::/48"), List.of("datacenter-1"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ab::/48"), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ac::/48"), List.of("datacenter-2", "datacenter-1") + ), + Map.of( + new CallRoutingTable.GeoKey("SA", "SR", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3"), + new CallRoutingTable.GeoKey("SA", "UY", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + new CallRoutingTable.GeoKey("NA", "US", Optional.of("VA"), CallRoutingTable.Protocol.v6), List.of("datacenter-2", "datacenter-1") + ) + ); + + assertThat(actual).isEqualTo(expected); + } + + @Test + public void testParserMissingSection() throws IOException { + var input = + """ + 192.1.12.0/24\t \tdatacenter-1\t\t datacenter-2 datacenter-3 + 193.123.123.0/24\tdatacenter-1\tdatacenter-2 + 1.123.123.0/24\t datacenter-1 + + SA-SR-v4 datacenter-3 + SA-UY-v4\tdatacenter-3\tdatacenter-1\tdatacenter-2 + NA-US-VA-v6 datacenter-2 \tdatacenter-1 + """; + var actual = CallRoutingTableParser.fromTsv(new StringReader(input)); + var expected = new CallRoutingTable( + Map.of( + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("192.1.12.0/24"), List.of("datacenter-1", "datacenter-2", "datacenter-3"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("193.123.123.0/24"), List.of("datacenter-1", "datacenter-2"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("1.123.123.0/24"), List.of("datacenter-1") + ), + Map.of(), + Map.of( + new CallRoutingTable.GeoKey("SA", "SR", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3"), + new CallRoutingTable.GeoKey("SA", "UY", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + new CallRoutingTable.GeoKey("NA", "US", Optional.of("VA"), CallRoutingTable.Protocol.v6), List.of("datacenter-2", "datacenter-1") + ) + ); + + assertThat(actual).isEqualTo(expected); + } + + + @Test + public void testParserMixedSections() throws IOException { + var input = + """ + + + 1.123.123.0/24\t datacenter-1 + 2001:db8:b0aa::/48\t \tdatacenter-1 + 2001:db8:b0ab::/48 \tdatacenter-3\tdatacenter-1 datacenter-2 + 2001:db8:b0ac::/48\tdatacenter-2\tdatacenter-1 + + + + 192.1.12.0/24\t \tdatacenter-1\t\t datacenter-2 datacenter-3 + 193.123.123.0/24\tdatacenter-1\tdatacenter-2 + + + + SA-SR-v4 datacenter-3 + + + + + SA-UY-v4\tdatacenter-3\tdatacenter-1\tdatacenter-2 + NA-US-VA-v6 datacenter-2 \tdatacenter-1 + """; + var actual = CallRoutingTableParser.fromTsv(new StringReader(input)); + var expected = new CallRoutingTable( + Map.of( + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("1.123.123.0/24"), List.of("datacenter-1"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("192.1.12.0/24"), List.of("datacenter-1", "datacenter-2", "datacenter-3"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("193.123.123.0/24"), List.of("datacenter-1", "datacenter-2") + ), + Map.of( + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0aa::/48"), List.of("datacenter-1"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ab::/48"), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ac::/48"), List.of("datacenter-2", "datacenter-1") + ), + Map.of( + new CallRoutingTable.GeoKey("SA", "SR", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3"), + new CallRoutingTable.GeoKey("SA", "UY", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + new CallRoutingTable.GeoKey("NA", "US", Optional.of("VA"), CallRoutingTable.Protocol.v6), List.of("datacenter-2", "datacenter-1") + ) + ); + + assertThat(actual).isEqualTo(expected); + } + + @Test + public void testJsonParserSuccess() throws IOException { + var input = + """ + { + "ipv4GeoToDataCenters": { + "SA-SR": ["datacenter-3"], + "SA-UY": ["datacenter-3", "datacenter-1", "datacenter-2"] + }, + "ipv6GeoToDataCenters": { + "NA-US-VA": ["datacenter-2", "datacenter-1"] + }, + "ipv4SubnetsToDatacenters": { + "192.1.12.0": ["datacenter-1", "datacenter-2", "datacenter-3"], + "193.123.123.0": ["datacenter-1", "datacenter-2"], + "1.123.123.0": ["datacenter-1"] + }, + "ipv6SubnetsToDatacenters": { + "2001:db8:b0aa::": ["datacenter-1"], + "2001:db8:b0ab::": ["datacenter-3", "datacenter-1", "datacenter-2"], + "2001:db8:b0ac::": ["datacenter-2", "datacenter-1"] + } + } + """; + var actual = CallRoutingTableParser.fromJson(new StringReader(input)); + var expected = new CallRoutingTable( + Map.of( + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("192.1.12.0/24"), List.of("datacenter-1", "datacenter-2", "datacenter-3"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("193.123.123.0/24"), List.of("datacenter-1", "datacenter-2"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("1.123.123.0/24"), List.of("datacenter-1") + ), + Map.of( + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0aa::/48"), List.of("datacenter-1"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ab::/48"), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ac::/48"), List.of("datacenter-2", "datacenter-1") + ), + Map.of( + new CallRoutingTable.GeoKey("SA", "SR", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3"), + new CallRoutingTable.GeoKey("SA", "UY", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + new CallRoutingTable.GeoKey("NA", "US", Optional.of("VA"), CallRoutingTable.Protocol.v6), List.of("datacenter-2", "datacenter-1") + ) + ); + + assertThat(actual).isEqualTo(expected); + } + + @Test + public void testParseVariousEdgeCases() throws IOException { + var input = + """ + { + "ipv4GeoToDataCenters": {}, + "ipv6GeoToDataCenters": {}, + "ipv4SubnetsToDatacenters": {}, + "ipv6SubnetsToDatacenters": {} + } + """; + assertThat(CallRoutingTableParser.fromJson(new StringReader(input))).isEqualTo(CallRoutingTable.empty()); + assertThat(CallRoutingTableParser.fromJson(new StringReader("{}"))).isEqualTo(CallRoutingTable.empty()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableTest.java new file mode 100644 index 00000000..f381411f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CallRoutingTableTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import io.vavr.Tuple2; +import org.junit.jupiter.api.Test; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.UnknownHostException; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class CallRoutingTableTest { + + static final CallRoutingTable basicTable = new CallRoutingTable( + Map.of( + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("192.1.12.0/24"), List.of("datacenter-1", "datacenter-2", "datacenter-3"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("193.123.123.0/24"), List.of("datacenter-1", "datacenter-2"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("1.123.123.0/24"), List.of("datacenter-4") + ), + Map.of( + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0aa::/48"), List.of("datacenter-1"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ab::/48"), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ac::/48"), List.of("datacenter-2", "datacenter-1") + ), + Map.of( + new CallRoutingTable.GeoKey("SA", "SR", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3"), + new CallRoutingTable.GeoKey("SA", "UY", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + new CallRoutingTable.GeoKey("NA", "US", Optional.of("VA"), CallRoutingTable.Protocol.v6), List.of("datacenter-2", "datacenter-1"), + new CallRoutingTable.GeoKey("NA", "US", Optional.empty(), CallRoutingTable.Protocol.v6), List.of("datacenter-3", "datacenter-4") + ) + ); + + // has overlapping subnets + static final CallRoutingTable overlappingTable = new CallRoutingTable( + Map.of( + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("192.1.12.0/24"), List.of("datacenter-1", "datacenter-2", "datacenter-3"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("1.123.123.0/24"), List.of("datacenter-4"), + (CidrBlock.IpV4CidrBlock) CidrBlock.parseCidrBlock("1.123.0.0/16"), List.of("datacenter-1") + ), + Map.of( + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0aa::/48"), List.of("datacenter-1"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0ac::/48"), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + (CidrBlock.IpV6CidrBlock) CidrBlock.parseCidrBlock("2001:db8:b0a0::/44"), List.of("datacenter-2", "datacenter-1") + ), + Map.of( + new CallRoutingTable.GeoKey("SA", "SR", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3"), + new CallRoutingTable.GeoKey("SA", "UY", Optional.empty(), CallRoutingTable.Protocol.v4), List.of("datacenter-3", "datacenter-1", "datacenter-2"), + new CallRoutingTable.GeoKey("NA", "US", Optional.of("VA"), CallRoutingTable.Protocol.v6), List.of("datacenter-2", "datacenter-1") + ) + ); + + @Test + void testGetFastestDataCentersBySubnet() throws UnknownHostException { + var v4address = Inet4Address.getByName("1.123.123.1"); + var actualV4 = basicTable.getDatacentersBySubnet(v4address); + assertThat(actualV4).isEqualTo(List.of("datacenter-4")); + + var v6address = Inet6Address.getByName("2001:db8:b0ac:aaaa:aaaa:aaaa:aaaa:0001"); + var actualV6 = basicTable.getDatacentersBySubnet(v6address); + assertThat(actualV6).isEqualTo(List.of("datacenter-2", "datacenter-1")); + } + + @Test + void testGetFastestDataCentersBySubnetOverlappingTable() throws UnknownHostException { + var v4address = Inet4Address.getByName("1.123.123.1"); + var actualV4 = overlappingTable.getDatacentersBySubnet(v4address); + assertThat(actualV4).isEqualTo(List.of("datacenter-4")); + + var v6address = Inet6Address.getByName("2001:db8:b0ac:aaaa:aaaa:aaaa:aaaa:0001"); + var actualV6 = overlappingTable.getDatacentersBySubnet(v6address); + assertThat(actualV6).isEqualTo(List.of("datacenter-3", "datacenter-1", "datacenter-2")); + } + + @Test + void testGetFastestDataCentersByGeo() { + var actual = basicTable.getDatacentersByGeo("SA", "SR", Optional.empty()); + assertThat(actual).isEqualTo(List.of("datacenter-3")); + + var actualWithSubdvision = basicTable.getDatacentersByGeo("NA", "US", Optional.of("VA")); + assertThat(actualWithSubdvision).isEqualTo(List.of("datacenter-2", "datacenter-1")); + } + + @Test + void testGetFastestDataCentersByGeoFallback() { + var actualExactMatch = basicTable.getDatacentersByGeo("NA", "US", Optional.of("VA")); + assertThat(actualExactMatch).isEqualTo(List.of("datacenter-2", "datacenter-1")); + + var actualApproximateMatch = basicTable.getDatacentersByGeo("NA", "US", Optional.of("MD")); + assertThat(actualApproximateMatch).isEqualTo(List.of("datacenter-3", "datacenter-4")); + } + + @Test + void testGetFastestDatacentersPrioritizesSubnet() throws UnknownHostException { + var v4address = Inet4Address.getByName("1.123.123.1"); + var actual = basicTable.getDatacentersFor(v4address, "NA", "US", Optional.of("VA")); + assertThat(actual).isEqualTo(List.of("datacenter-4", "datacenter-2", "datacenter-1")); + } + + @Test + void testGetFastestDatacentersEmptySubnet() throws UnknownHostException { + var v4address = Inet4Address.getByName("200.200.123.1"); + var actual = basicTable.getDatacentersFor(v4address, "NA", "US", Optional.of("VA")); + assertThat(actual).isEqualTo(List.of("datacenter-2", "datacenter-1")); + } + + @Test + void testGetFastestDatacentersEmptySubnetTakesExtraFromGeo() throws UnknownHostException { + var v4address = Inet4Address.getByName("200.200.123.1"); + var actual = basicTable.getDatacentersFor(v4address, "SA", "UY", Optional.empty()); + assertThat(actual).isEqualTo(List.of("datacenter-3", "datacenter-1", "datacenter-2")); + } + + @Test + void testGetFastestDatacentersEmptyGeoResults() throws UnknownHostException { + var v4address = Inet4Address.getByName("1.123.123.1"); + var actual = basicTable.getDatacentersFor(v4address, "ZZ", "AA", Optional.empty()); + assertThat(actual).isEqualTo(List.of("datacenter-4")); + } + + @Test + void testGetFastestDatacentersEmptyGeoTakesFromSubnet() throws UnknownHostException { + var v4address = Inet4Address.getByName("192.1.12.1"); + var actual = basicTable.getDatacentersFor(v4address, "ZZ", "AA", Optional.empty()); + assertThat(actual).isEqualTo(List.of("datacenter-1", "datacenter-2", "datacenter-3")); + } + + @Test + void testGetFastestDatacentersDistinct() throws UnknownHostException { + var v6address = Inet6Address.getByName("2001:db8:b0ac:aaaa:aaaa:aaaa:aaaa:0001"); + var actual = basicTable.getDatacentersFor(v6address, "NA", "US", Optional.of("VA")); + assertThat(actual).isEqualTo(List.of("datacenter-2", "datacenter-1")); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CidrBlockTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CidrBlockTest.java new file mode 100644 index 00000000..74aa2d6d --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/CidrBlockTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import org.junit.jupiter.api.Test; + +import java.math.BigInteger; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.HexFormat; + +import static org.assertj.core.api.Assertions.assertThat; + +public class CidrBlockTest { + + private HexFormat hex = HexFormat.ofDelimiter(":").withLowerCase(); + + @Test + public void testIPv4CidrBlockParseSuccess() { + var actual = CidrBlock.parseCidrBlock("255.32.15.0/24"); + var expected = new CidrBlock.IpV4CidrBlock(0xFF_20_0F_00, 0xFFFFFF00, 24); + + assertThat(actual).isInstanceOf(CidrBlock.IpV4CidrBlock.class); + assertThat(actual).isEqualTo(expected); + } + + @Test + public void testIPv6CidrBlockParseSuccess() { + var actual = CidrBlock.parseCidrBlock("2001:db8:b0aa::/48"); + var expected = new CidrBlock.IpV6CidrBlock( + new BigInteger(hex.parseHex("20:01:0d:b8:b0:aa:00:00:00:00:00:00:00:00:00:00")), + new BigInteger(hex.parseHex("FF:FF:FF:FF:FF:FF:00:00:00:00:00:00:00:00:00:00")), + 48 + ); + + assertThat(actual).isInstanceOf(CidrBlock.IpV6CidrBlock.class); + assertThat(actual).isEqualTo(expected); + } + + @Test + public void testIPv4InBlock() throws UnknownHostException { + var block = CidrBlock.parseCidrBlock("255.32.15.0/24"); + + assertThat(block.ipInBlock(InetAddress.getByName("255.32.15.123"))).isTrue(); + assertThat(block.ipInBlock(InetAddress.getByName("255.32.15.0"))).isTrue(); + assertThat(block.ipInBlock(InetAddress.getByName("255.32.16.0"))).isFalse(); + assertThat(block.ipInBlock(InetAddress.getByName("255.33.15.0"))).isFalse(); + assertThat(block.ipInBlock(InetAddress.getByName("254.33.15.0"))).isFalse(); + } + + @Test + public void testIPv6InBlock() throws UnknownHostException { + var block = CidrBlock.parseCidrBlock("2001:db8:b0aa::/48"); + + assertThat(block.ipInBlock(InetAddress.getByName("2001:db8:b0aa:1:1::"))).isTrue(); + assertThat(block.ipInBlock(InetAddress.getByName("2001:db8:b0aa:0:0::"))).isTrue(); + assertThat(block.ipInBlock(InetAddress.getByName("2001:db8:b0ab:1:1::"))).isFalse(); + assertThat(block.ipInBlock(InetAddress.getByName("2001:da8:b0aa:1:1::"))).isFalse(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/DynamicConfigTurnRouterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/DynamicConfigTurnRouterTest.java new file mode 100644 index 00000000..9465637d --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/DynamicConfigTurnRouterTest.java @@ -0,0 +1,137 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import com.fasterxml.jackson.core.JsonProcessingException; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class DynamicConfigTurnRouterTest { + @Test + public void testAlwaysSelectFirst() throws JsonProcessingException { + final String configString = """ + captcha: + scoreFloor: 1.0 + turn: + uriConfigs: + - uris: + - always1.org + - always2.org + - uris: + - never.org + weight: 0 + """; + DynamicConfiguration config = DynamicConfigurationManager + .parseConfiguration(configString, DynamicConfiguration.class) + .orElseThrow(); + + @SuppressWarnings("unchecked") + DynamicConfigurationManager mockDynamicConfigManager = mock( + DynamicConfigurationManager.class); + + when(mockDynamicConfigManager.getConfiguration()).thenReturn(config); + + final DynamicConfigTurnRouter configTurnRouter = new DynamicConfigTurnRouter(mockDynamicConfigManager); + + final long COUNT = 1000; + + final Map urlCounts = Stream + .generate(configTurnRouter::randomUrls) + .limit(COUNT) + .flatMap(Collection::stream) + .collect(Collectors.groupingBy(i -> i, Collectors.counting())); + + assertThat(urlCounts.get("always1.org")).isEqualTo(COUNT); + assertThat(urlCounts.get("always2.org")).isEqualTo(COUNT); + assertThat(urlCounts).doesNotContainKey("never.org"); + } + + @Test + public void testProbabilisticUrls() throws JsonProcessingException { + final String configString = """ + captcha: + scoreFloor: 1.0 + turn: + uriConfigs: + - uris: + - always.org + - sometimes1.org + weight: 5 + - uris: + - always.org + - sometimes2.org + weight: 5 + """; + DynamicConfiguration config = DynamicConfigurationManager + .parseConfiguration(configString, DynamicConfiguration.class) + .orElseThrow(); + + @SuppressWarnings("unchecked") + DynamicConfigurationManager mockDynamicConfigManager = mock( + DynamicConfigurationManager.class); + + when(mockDynamicConfigManager.getConfiguration()).thenReturn(config); + final DynamicConfigTurnRouter configTurnRouter = new DynamicConfigTurnRouter(mockDynamicConfigManager); + + final long COUNT = 1000; + + final Map urlCounts = Stream + .generate(configTurnRouter::randomUrls) + .limit(COUNT) + .flatMap(Collection::stream) + .collect(Collectors.groupingBy(i -> i, Collectors.counting())); + + assertThat(urlCounts.get("always.org")).isEqualTo(COUNT); + assertThat(urlCounts.get("sometimes1.org")).isGreaterThan(0); + assertThat(urlCounts.get("sometimes2.org")).isGreaterThan(0); + } + + @Test + public void testExplicitEnrollment() throws JsonProcessingException { + final String configString = """ + captcha: + scoreFloor: 1.0 + turn: + secret: bloop + uriConfigs: + - uris: + - enrolled.org + weight: 0 + enrolledAcis: + - 732506d7-d04f-43a4-b1d7-8a3a91ebe8a6 + - uris: + - unenrolled.org + weight: 1 + """; + DynamicConfiguration config = DynamicConfigurationManager + .parseConfiguration(configString, DynamicConfiguration.class) + .orElseThrow(); + + @SuppressWarnings("unchecked") + DynamicConfigurationManager mockDynamicConfigManager = mock( + DynamicConfigurationManager.class); + + when(mockDynamicConfigManager.getConfiguration()).thenReturn(config); + final DynamicConfigTurnRouter configTurnRouter = new DynamicConfigTurnRouter(mockDynamicConfigManager); + + List urls = configTurnRouter.targetedUrls(UUID.fromString("732506d7-d04f-43a4-b1d7-8a3a91ebe8a6")); + assertThat(urls.getFirst()).isEqualTo("enrolled.org"); + urls = configTurnRouter.targetedUrls(UUID.randomUUID()); + assertTrue(urls.isEmpty()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/TurnCallRouterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/TurnCallRouterTest.java new file mode 100644 index 00000000..56323fde --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/calls/routing/TurnCallRouterTest.java @@ -0,0 +1,262 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.calls.routing; + +import com.maxmind.geoip2.DatabaseReader; +import com.maxmind.geoip2.exception.GeoIp2Exception; +import com.maxmind.geoip2.model.CityResponse; +import com.maxmind.geoip2.record.Continent; +import com.maxmind.geoip2.record.Country; +import com.maxmind.geoip2.record.Subdivision; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TurnCallRouterTest { + + private final static String TEST_HOSTNAME = "subdomain.example.org"; + private final static List TEST_URLS_WITH_HOSTS = List.of( + "one.example.com", + "two.example.com", + "three.example.com" + ); + + private CallRoutingTable performanceTable; + private CallRoutingTable manualTable; + private DynamicConfigTurnRouter configTurnRouter; + private DatabaseReader geoIp; + private Country country; + private Continent continent; + private CallDnsRecords callDnsRecords; + private Subdivision subdivision; + private UUID aci = UUID.randomUUID(); + + @BeforeEach + void setup() throws IOException, GeoIp2Exception { + performanceTable = mock(CallRoutingTable.class); + manualTable = mock(CallRoutingTable.class); + configTurnRouter = mock(DynamicConfigTurnRouter.class); + geoIp = mock(DatabaseReader.class); + continent = mock(Continent.class); + country = mock(Country.class); + subdivision = mock(Subdivision.class); + ArrayList subdivisions = new ArrayList<>(); + subdivisions.add(subdivision); + + when(geoIp.city(any())).thenReturn(new CityResponse(null, continent, country, null, null, null, null, null, subdivisions, null)); + setupDefault(); + } + + void setupDefault() { + when(configTurnRouter.targetedUrls(any())).thenReturn(Collections.emptyList()); + when(configTurnRouter.randomUrls()).thenReturn(TEST_URLS_WITH_HOSTS); + when(configTurnRouter.getHostname()).thenReturn(TEST_HOSTNAME); + when(configTurnRouter.shouldRandomize()).thenReturn(false); + when(manualTable.getDatacentersFor(any(), any(), any(), any())).thenReturn(Collections.emptyList()); + when(continent.getCode()).thenReturn("NA"); + when(country.getIsoCode()).thenReturn("US"); + when(subdivision.getIsoCode()).thenReturn("VA"); + try { + callDnsRecords = new CallDnsRecords( + Map.of( + "dc-manual", List.of(InetAddress.getByName("1.1.1.1")), + "dc-performance1", List.of( + InetAddress.getByName("9.9.9.1"), + InetAddress.getByName("9.9.9.2") + ), + "dc-performance2", List.of(InetAddress.getByName("9.9.9.3")), + "dc-performance3", List.of(InetAddress.getByName("9.9.9.4")) + ), + Map.of( + "dc-manual", List.of(InetAddress.getByName("2222:1111:0:dead::")), + "dc-performance1", List.of( + InetAddress.getByName("2222:1111:0:abc0::"), + InetAddress.getByName("2222:1111:0:abc1::") + ), + "dc-performance2", List.of(InetAddress.getByName("2222:1111:0:abc2::")), + "dc-performance3", List.of(InetAddress.getByName("2222:1111:0:abc3::")) + ) + ); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + } + + private TurnCallRouter router() { + return new TurnCallRouter( + () -> callDnsRecords, + () -> performanceTable, + () -> manualTable, + configTurnRouter, + () -> geoIp + ); + } + + TurnServerOptions optionsWithUrls(List urls) { + return new TurnServerOptions( + TEST_HOSTNAME, + urls, + TEST_URLS_WITH_HOSTS + ); + } + + @Test + public void testPrioritizesTargetedUrls() throws UnknownHostException { + List targetedUrls = List.of( + "targeted1.example.com", + "targeted.example.com" + ); + when(configTurnRouter.targetedUrls(any())) + .thenReturn(targetedUrls); + + assertThat(router().getRoutingFor(aci, Optional.of(InetAddress.getByName("0.0.0.1")), 10)) + .isEqualTo(new TurnServerOptions( + TEST_HOSTNAME, + null, + targetedUrls + )); + } + + @Test + public void testRandomizes() throws UnknownHostException { + when(configTurnRouter.shouldRandomize()) + .thenReturn(true); + + assertThat(router().getRoutingFor(aci, Optional.of(InetAddress.getByName("0.0.0.1")), 10)) + .isEqualTo(optionsWithUrls(null)); + } + + @Test + public void testOrderedByPerformance() throws UnknownHostException { + when(performanceTable.getDatacentersFor(any(), any(), any(), any())) + .thenReturn(List.of("dc-performance2", "dc-performance1")); + + assertThat(router().getRoutingFor(aci, Optional.of(InetAddress.getByName("0.0.0.1")), 10)) + .isEqualTo(optionsWithUrls(List.of( + "stun:9.9.9.3", + "turn:9.9.9.3", + "turn:9.9.9.3:80?transport=tcp", + "turns:9.9.9.3:443?transport=tcp", + + "stun:9.9.9.1", + "turn:9.9.9.1", + "turn:9.9.9.1:80?transport=tcp", + "turns:9.9.9.1:443?transport=tcp", + + "stun:9.9.9.2", + "turn:9.9.9.2", + "turn:9.9.9.2:80?transport=tcp", + "turns:9.9.9.2:443?transport=tcp", + + "stun:2222:1111:0:abc2:0:0:0:0", + "turn:2222:1111:0:abc2:0:0:0:0", + "turn:2222:1111:0:abc2:0:0:0:0:80?transport=tcp", + "turns:2222:1111:0:abc2:0:0:0:0:443?transport=tcp", + + "stun:2222:1111:0:abc0:0:0:0:0", + "turn:2222:1111:0:abc0:0:0:0:0", + "turn:2222:1111:0:abc0:0:0:0:0:80?transport=tcp", + "turns:2222:1111:0:abc0:0:0:0:0:443?transport=tcp", + + "stun:2222:1111:0:abc1:0:0:0:0", + "turn:2222:1111:0:abc1:0:0:0:0", + "turn:2222:1111:0:abc1:0:0:0:0:80?transport=tcp", + "turns:2222:1111:0:abc1:0:0:0:0:443?transport=tcp" + ))); + } + + @Test + public void testPrioritizesManualRecords() throws UnknownHostException { + when(performanceTable.getDatacentersFor(any(), any(), any(), any())) + .thenReturn(List.of("dc-performance1")); + when(manualTable.getDatacentersFor(any(), any(), any(), any())) + .thenReturn(List.of("dc-manual")); + + assertThat(router().getRoutingFor(aci, Optional.of(InetAddress.getByName("0.0.0.1")), 10)) + .isEqualTo(optionsWithUrls(List.of( + "stun:1.1.1.1", + "turn:1.1.1.1", + "turn:1.1.1.1:80?transport=tcp", + "turns:1.1.1.1:443?transport=tcp", + + "stun:2222:1111:0:dead:0:0:0:0", + "turn:2222:1111:0:dead:0:0:0:0", + "turn:2222:1111:0:dead:0:0:0:0:80?transport=tcp", + "turns:2222:1111:0:dead:0:0:0:0:443?transport=tcp" + ))); + } + + @Test + public void testLimitReturnsHalfIpv4AndPrioritizesPerformance() throws UnknownHostException { + when(performanceTable.getDatacentersFor(any(), any(), any(), any())) + .thenReturn(List.of("dc-performance3", "dc-performance2", "dc-performance1")); + + assertThat(router().getRoutingFor(aci, Optional.of(InetAddress.getByName("0.0.0.1")), 6)) + .isEqualTo(optionsWithUrls(List.of( + "stun:9.9.9.4", + "turn:9.9.9.4", + "turn:9.9.9.4:80?transport=tcp", + "turns:9.9.9.4:443?transport=tcp", + + "stun:9.9.9.3", + "turn:9.9.9.3", + "turn:9.9.9.3:80?transport=tcp", + "turns:9.9.9.3:443?transport=tcp", + + "stun:9.9.9.1", + "turn:9.9.9.1", + "turn:9.9.9.1:80?transport=tcp", + "turns:9.9.9.1:443?transport=tcp", + + "stun:2222:1111:0:abc3:0:0:0:0", + "turn:2222:1111:0:abc3:0:0:0:0", + "turn:2222:1111:0:abc3:0:0:0:0:80?transport=tcp", + "turns:2222:1111:0:abc3:0:0:0:0:443?transport=tcp", + + "stun:2222:1111:0:abc2:0:0:0:0", + "turn:2222:1111:0:abc2:0:0:0:0", + "turn:2222:1111:0:abc2:0:0:0:0:80?transport=tcp", + "turns:2222:1111:0:abc2:0:0:0:0:443?transport=tcp", + + "stun:2222:1111:0:abc0:0:0:0:0", + "turn:2222:1111:0:abc0:0:0:0:0", + "turn:2222:1111:0:abc0:0:0:0:0:80?transport=tcp", + "turns:2222:1111:0:abc0:0:0:0:0:443?transport=tcp" + ))); + } + + @Test + public void testNoDatacentersMatched() throws UnknownHostException { + when(performanceTable.getDatacentersFor(any(), any(), any(), any())) + .thenReturn(List.of()); + + assertThat(router().getRoutingFor(aci, Optional.of(InetAddress.getByName("0.0.0.1")), 10)) + .isEqualTo(optionsWithUrls(List.of())); + } + + @Test + public void testHandlesDatacenterNotInDnsRecords() throws UnknownHostException { + when(performanceTable.getDatacentersFor(any(), any(), any(), any())) + .thenReturn(List.of("unsynced-datacenter")); + + assertThat(router().getRoutingFor(aci, Optional.of(InetAddress.getByName("0.0.0.1")), 10)) + .isEqualTo(optionsWithUrls(List.of())); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfigurationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfigurationTest.java index fe698b1b..6132e2a3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfigurationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfigurationTest.java @@ -333,6 +333,8 @@ class DynamicConfigurationTest { weight: 2 enrolledAcis: - 732506d7-d04f-43a4-b1d7-8a3a91ebe8a6 + randomizeRate: 100_000 + hostname: test.domain.org """); DynamicTurnConfiguration turnConfiguration = DynamicConfigurationManager .parseConfiguration(config, DynamicConfiguration.class) @@ -345,6 +347,8 @@ class DynamicConfigurationTest { assertThat(turnConfiguration.getUriConfigs().get(1).getEnrolledAcis()) .containsExactly(UUID.fromString("732506d7-d04f-43a4-b1d7-8a3a91ebe8a6")); + assertThat(turnConfiguration.getHostname()).isEqualTo("test.domain.org"); + assertThat(turnConfiguration.getRandomizeRate()).isEqualTo(100_000L); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerTest.java new file mode 100644 index 00000000..8565cf36 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerTest.java @@ -0,0 +1,142 @@ +/* + * Copyright 2013 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.net.HttpHeaders; +import io.dropwizard.auth.AuthValueFactoryProvider; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Optional; +import javax.ws.rs.core.Response; +import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.TurnToken; +import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; +import org.whispersystems.textsecuregcm.calls.routing.TurnCallRouter; +import org.whispersystems.textsecuregcm.calls.routing.TurnServerOptions; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; + +@ExtendWith(DropwizardExtensionsSupport.class) +class CallRoutingControllerTest { + private static final RateLimiters rateLimiters = mock(RateLimiters.class); + private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class); + private static final DynamicConfigurationManager configManager = mock(DynamicConfigurationManager.class); + private static final TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(configManager, "bloop".getBytes( + StandardCharsets.UTF_8)); + private static final TurnCallRouter turnCallRouter = mock(TurnCallRouter.class); + private static final String GET_CALL_ENDPOINTS_PATH = "v1/calling/relays"; + private static final String REMOTE_ADDRESS = "123.123.123.1"; + + private static final ResourceExtension resources = ResourceExtension.builder() + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedAccount.class)) + .addProvider(new RateLimitExceededExceptionMapper()) + .addProvider(new TestRemoteAddressFilterProvider(REMOTE_ADDRESS)) + .setMapper(SystemMapper.jsonMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new CallRoutingController(rateLimiters, turnCallRouter, turnTokenGenerator)) + .build(); + + @BeforeEach + void setup() { + when(rateLimiters.getCallEndpointLimiter()).thenReturn(getCallEndpointLimiter); + } + + @Test + void testGetTurnEndpointsSuccess() throws UnknownHostException { + TurnServerOptions options = new TurnServerOptions( + "example.domain.org", + List.of("stun:12.34.56.78"), + List.of("stun:example.domain.org") + ); + + when(turnCallRouter.getRoutingFor( + eq(AuthHelper.VALID_UUID), + eq(Optional.of(InetAddress.getByName(REMOTE_ADDRESS))), + anyInt()) + ).thenReturn(options); + try(Response response = resources.getJerseyTest() + .target(GET_CALL_ENDPOINTS_PATH) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertThat(response.getStatus()).isEqualTo(200); + TurnToken token = response.readEntity(TurnToken.class); + assertThat(token.username()).isNotEmpty(); + assertThat(token.password()).isNotEmpty(); + assertThat(token.hostname()).isEqualTo(options.hostname()); + assertThat(token.urlsWithIps()).isEqualTo(options.urlsWithIps()); + assertThat(token.urls()).isEqualTo(options.urlsWithHostname()); + } + } + + @Test + void testGetTurnEndpointsInvalidIpSuccess() throws UnknownHostException { + TurnServerOptions options = new TurnServerOptions( + "example.domain.org", + List.of(), + List.of("stun:example.domain.org") + ); + + when(turnCallRouter.getRoutingFor( + eq(AuthHelper.VALID_UUID), + eq(Optional.of(InetAddress.getByName(REMOTE_ADDRESS))), + anyInt()) + ).thenReturn(options); + try(Response response = resources.getJerseyTest() + .target(GET_CALL_ENDPOINTS_PATH) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertThat(response.getStatus()).isEqualTo(200); + TurnToken token = response.readEntity(TurnToken.class); + assertThat(token.username()).isNotEmpty(); + assertThat(token.password()).isNotEmpty(); + assertThat(token.hostname()).isEqualTo(options.hostname()); + assertThat(token.urlsWithIps()).isEqualTo(options.urlsWithIps()); + assertThat(token.urls()).isEqualTo(options.urlsWithHostname()); + } + } + + @Test + void testGetTurnEndpointRateLimited() throws RateLimitExceededException { + doThrow(new RateLimitExceededException(null, false)) + .when(getCallEndpointLimiter).validate(AuthHelper.VALID_UUID); + + try(final Response response = resources.getJerseyTest() + .target(GET_CALL_ENDPOINTS_PATH) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertThat(response.getStatus()).isEqualTo(429); + } + } +}