0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-19 19:42:18 +02:00

Use the UA string from websocket upgrade requests if available.

This commit is contained in:
Jon Chambers 2020-06-11 14:03:05 -04:00 committed by Jon Chambers
parent 7454e55693
commit bbf5e1fa78
2 changed files with 65 additions and 0 deletions

View File

@ -186,6 +186,63 @@ public class MetricsRequestEventListenerTest {
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(request.getHeader("User-Agent")).thenReturn("Signal-Android 4.53.7 (Android 8.1)");
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsRequestEventListener.COUNTER_NAME), any(Iterable.class))).thenReturn(counter);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getStatus()).isEqualTo(200);
verify(meterRegistry).counter(eq(MetricsRequestEventListener.COUNTER_NAME), tagCaptor.capture());
final Iterable<Tag> tagIterable = tagCaptor.getValue();
final Set<Tag> tags = new HashSet<>();
for (final Tag tag : tagIterable) {
tags.add(tag);
}
assertEquals(5, tags.size());
assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, "/v1/test/hello")));
assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(200))));
assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.VERSION_TAG, "4.53.7")));
}
@Test
public void testActualRouteMessageSuccessNoUserAgent() throws InvalidProtocolBufferException {
MetricsApplicationEventListener applicationEventListener = mock(MetricsApplicationEventListener.class);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(applicationEventListener);
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);

View File

@ -71,6 +71,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private Session session;
private RemoteEndpoint remoteEndpoint;
private WebSocketSessionContext context;
private String userAgent;
public WebSocketResourceProvider(String remoteAddress,
ApplicationHandler jerseyHandler,
@ -92,6 +93,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
@Override
public void onWebSocketConnect(Session session) {
this.session = session;
this.userAgent = session.getUpgradeRequest().getHeader("User-Agent");
this.remoteEndpoint = session.getRemote();
this.context = new WebSocketSessionContext(new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
this.context.setAuthenticated(authenticated);
@ -157,6 +159,12 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
containerRequest.header(entry.getKey(), entry.getValue());
}
final List<String> requestUserAgentHeader = containerRequest.getRequestHeader("User-Agent");
if ((requestUserAgentHeader == null || requestUserAgentHeader.isEmpty()) && userAgent != null) {
containerRequest.header("User-Agent", userAgent);
}
if (requestMessage.getBody().isPresent()) {
containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get()));
}