From bbf5e1fa78b62744226a0a4dd1fe0b60e672dfdc Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 11 Jun 2020 14:03:05 -0400 Subject: [PATCH] Use the UA string from websocket upgrade requests if available. --- .../MetricsRequestEventListenerTest.java | 57 +++++++++++++++++++ .../websocket/WebSocketResourceProvider.java | 8 +++ 2 files changed, 65 insertions(+) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index 632cfc8f..983898c0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -186,6 +186,63 @@ public class MetricsRequestEventListenerTest { RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + when(request.getHeader("User-Agent")).thenReturn("Signal-Android 4.53.7 (Android 8.1)"); + + final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + when(meterRegistry.counter(eq(MetricsRequestEventListener.COUNTER_NAME), any(Iterable.class))).thenReturn(counter); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getStatus()).isEqualTo(200); + + verify(meterRegistry).counter(eq(MetricsRequestEventListener.COUNTER_NAME), tagCaptor.capture()); + + final Iterable tagIterable = tagCaptor.getValue(); + final Set tags = new HashSet<>(); + + for (final Tag tag : tagIterable) { + tags.add(tag); + } + + assertEquals(5, tags.size()); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, "/v1/test/hello"))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(200)))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.VERSION_TAG, "4.53.7"))); + } + + @Test + public void testActualRouteMessageSuccessNoUserAgent() throws InvalidProtocolBufferException { + MetricsApplicationEventListener applicationEventListener = mock(MetricsApplicationEventListener.class); + when(applicationEventListener.onRequest(any())).thenReturn(listener); + + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(applicationEventListener); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + when(session.getUpgradeRequest()).thenReturn(request); when(session.getRemote()).thenReturn(remoteEndpoint); diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index 6cabf7b1..4fa9f141 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -71,6 +71,7 @@ public class WebSocketResourceProvider implements WebSocket private Session session; private RemoteEndpoint remoteEndpoint; private WebSocketSessionContext context; + private String userAgent; public WebSocketResourceProvider(String remoteAddress, ApplicationHandler jerseyHandler, @@ -92,6 +93,7 @@ public class WebSocketResourceProvider implements WebSocket @Override public void onWebSocketConnect(Session session) { this.session = session; + this.userAgent = session.getUpgradeRequest().getHeader("User-Agent"); this.remoteEndpoint = session.getRemote(); this.context = new WebSocketSessionContext(new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap)); this.context.setAuthenticated(authenticated); @@ -157,6 +159,12 @@ public class WebSocketResourceProvider implements WebSocket containerRequest.header(entry.getKey(), entry.getValue()); } + final List requestUserAgentHeader = containerRequest.getRequestHeader("User-Agent"); + + if ((requestUserAgentHeader == null || requestUserAgentHeader.isEmpty()) && userAgent != null) { + containerRequest.header("User-Agent", userAgent); + } + if (requestMessage.getBody().isPresent()) { containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get())); }