0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-19 19:42:18 +02:00

Enable header-based auth for WebSocket connections

This commit is contained in:
Sergey Skrobotov 2023-09-25 11:28:23 -07:00
parent a263611746
commit d0fdae3df7
8 changed files with 147 additions and 85 deletions

View File

@ -18,6 +18,7 @@ import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
/**
* A basic credential authentication interceptor enforces the presence of a valid username and password on every call.
@ -39,7 +40,7 @@ public class BasicCredentialAuthenticationInterceptor implements ServerIntercept
@VisibleForTesting
static final Metadata.Key<String> BASIC_CREDENTIALS =
Metadata.Key.of("x-signal-basic-auth-credentials", Metadata.ASCII_STRING_MARSHALLER);
Metadata.Key.of("x-signal-auth", Metadata.ASCII_STRING_MARSHALLER);
private static final Metadata EMPTY_TRAILERS = new Metadata();
@ -48,17 +49,20 @@ public class BasicCredentialAuthenticationInterceptor implements ServerIntercept
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
final String credentialString = headers.get(BASIC_CREDENTIALS);
final String authHeader = headers.get(BASIC_CREDENTIALS);
if (StringUtils.isNotBlank(credentialString)) {
try {
final BasicCredentials credentials = extractBasicCredentials(credentialString);
if (StringUtils.isNotBlank(authHeader)) {
final Optional<BasicCredentials> maybeCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeader);
if (maybeCredentials.isEmpty()) {
call.close(Status.UNAUTHENTICATED.withDescription("Could not parse credentials"), EMPTY_TRAILERS);
} else {
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
baseAccountAuthenticator.authenticate(credentials, false);
baseAccountAuthenticator.authenticate(maybeCredentials.get(), false);
if (maybeAuthenticatedAccount.isPresent()) {
final AuthenticatedAccount authenticatedAccount = maybeAuthenticatedAccount.get();
@ -71,8 +75,6 @@ public class BasicCredentialAuthenticationInterceptor implements ServerIntercept
} else {
call.close(Status.UNAUTHENTICATED.withDescription("Credentials not accepted"), EMPTY_TRAILERS);
}
} catch (final IllegalArgumentException e) {
call.close(Status.UNAUTHENTICATED.withDescription("Could not parse credentials"), EMPTY_TRAILERS);
}
} else {
call.close(Status.UNAUTHENTICATED.withDescription("No credentials provided"), EMPTY_TRAILERS);
@ -80,15 +82,4 @@ public class BasicCredentialAuthenticationInterceptor implements ServerIntercept
return new ServerCall.Listener<>() {};
}
@VisibleForTesting
static BasicCredentials extractBasicCredentials(final String credentials) {
if (credentials.indexOf(':') < 0) {
throw new IllegalArgumentException("Credentials do not include a username and password part");
}
final String[] pieces = credentials.split(":", 2);
return new BasicCredentials(pieces[0], pieces[1]);
}
}

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.util;
import static java.util.Objects.requireNonNull;
import io.dropwizard.auth.basic.BasicCredentials;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Optional;
@ -63,4 +64,38 @@ public final class HeaderUtils {
})
.filter(StringUtils::isNotBlank);
}
/**
* Parses a Base64-encoded value of the `Authorization` header
* in the form of `Basic dXNlcm5hbWU6cGFzc3dvcmQ=`.
* Note: parsing logic is copied from {@link io.dropwizard.auth.basic.BasicCredentialAuthFilter#getCredentials(String)}.
*/
public static Optional<BasicCredentials> basicCredentialsFromAuthHeader(final String authHeader) {
final int space = authHeader.indexOf(' ');
if (space <= 0) {
return Optional.empty();
}
final String method = authHeader.substring(0, space);
if (!"Basic".equalsIgnoreCase(method)) {
return Optional.empty();
}
final String decoded;
try {
decoded = new String(Base64.getDecoder().decode(authHeader.substring(space + 1)), StandardCharsets.UTF_8);
} catch (IllegalArgumentException e) {
return Optional.empty();
}
// Decoded credentials is 'username:password'
final int i = decoded.indexOf(':');
if (i <= 0) {
return Optional.empty();
}
final String username = decoded.substring(0, i);
final String password = decoded.substring(i + 1);
return Optional.of(new BasicCredentials(username, password));
}
}

View File

@ -1,14 +1,18 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentialsFromAuthHeader;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@ -18,29 +22,32 @@ import org.whispersystems.websocket.auth.WebSocketAuthenticator;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedAccount> {
private static final AuthenticationResult<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED =
new AuthenticationResult<>(Optional.empty(), false);
private static final AuthenticationResult<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED =
new AuthenticationResult<>(Optional.empty(), true);
private final AccountAuthenticator accountAuthenticator;
public WebSocketAccountAuthenticator(AccountAuthenticator accountAuthenticator) {
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) {
this.accountAuthenticator = accountAuthenticator;
}
@Override
public AuthenticationResult<AuthenticatedAccount> authenticate(UpgradeRequest request)
public AuthenticationResult<AuthenticatedAccount> authenticate(final UpgradeRequest request)
throws AuthenticationException {
Map<String, List<String>> parameters = request.getParameterMap();
List<String> usernames = parameters.get("login");
List<String> passwords = parameters.get("password");
if (usernames == null || usernames.size() == 0 ||
passwords == null || passwords.size() == 0) {
return new AuthenticationResult<>(Optional.empty(), false);
}
BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
passwords.get(0).replace(" ", "+"));
try {
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true);
final AuthenticationResult<AuthenticatedAccount> authResultFromHeader =
authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION));
// the logic here is that if the `Authorization` header was set for the request,
// it takes the priority and we use the result of the header-based auth
// ignoring the result of the query-based auth.
if (authResultFromHeader.credentialsPresented()) {
return authResultFromHeader;
}
return authenticatedAccountFromQueryParams(request);
} catch (final Exception e) {
// this will be handled and logged upstream
// the most likely exception is a transient error connecting to account storage
@ -48,4 +55,26 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
}
}
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) {
final Map<String, List<String>> parameters = request.getParameterMap();
final List<String> usernames = parameters.get("login");
final List<String> passwords = parameters.get("password");
if (usernames == null || usernames.size() == 0 ||
passwords == null || passwords.size() == 0) {
return CREDENTIALS_NOT_PRESENTED;
}
final BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
passwords.get(0).replace(" ", "+"));
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true);
}
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader)
throws AuthenticationException {
if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED;
}
return basicCredentialsFromAuthHeader(authHeader)
.map(credentials -> new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true))
.orElse(INVALID_CREDENTIALS_PRESENTED);
}
}

View File

@ -13,7 +13,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.basic.BasicCredentials;
import io.grpc.CallCredentials;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
@ -30,7 +29,6 @@ import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
@ -41,6 +39,7 @@ import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
class BasicCredentialAuthenticationInterceptorTest {
@ -122,8 +121,10 @@ class BasicCredentialAuthenticationInterceptorTest {
malformedCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, "Incorrect");
final Metadata structurallyValidCredentialHeaders = new Metadata();
structurallyValidCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS,
UUID.randomUUID() + ":" + RandomStringUtils.randomAlphanumeric(16));
structurallyValidCredentialHeaders.put(
BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS,
HeaderUtils.basicAuthHeader(UUID.randomUUID().toString(), RandomStringUtils.randomAlphanumeric(16))
);
return Stream.of(
Arguments.of(new Metadata(), true, false),
@ -132,22 +133,4 @@ class BasicCredentialAuthenticationInterceptorTest {
Arguments.of(structurallyValidCredentialHeaders, true, true)
);
}
@Test
void extractBasicCredentials() {
final String username = UUID.randomUUID().toString();
final String password = RandomStringUtils.random(16);
final BasicCredentials basicCredentials =
BasicCredentialAuthenticationInterceptor.extractBasicCredentials(username + ":" + password);
assertEquals(username, basicCredentials.getUsername());
assertEquals(password, basicCredentials.getPassword());
}
@Test
void extractBasicCredentialsIllegalArgument() {
assertThrows(IllegalArgumentException.class,
() -> BasicCredentialAuthenticationInterceptor.extractBasicCredentials("This does not include a password"));
}
}

View File

@ -6,17 +6,18 @@
package org.whispersystems.textsecuregcm.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
@ -26,6 +27,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
@ -33,9 +35,12 @@ class WebSocketAccountAuthenticatorTest {
private static final String VALID_USER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("NZ"), PhoneNumberUtil.PhoneNumberFormat.E164);
private static final String VALID_PASSWORD = "valid";
private static final String INVALID_USER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("AU"), PhoneNumberUtil.PhoneNumberFormat.E164);
private static final String INVALID_PASSWORD = "invalid";
private AccountAuthenticator accountAuthenticator;
@ -57,10 +62,16 @@ class WebSocketAccountAuthenticatorTest {
@ParameterizedTest
@MethodSource
void testAuthenticate(final Map<String, List<String>> upgradeRequestParameters, final boolean expectAccount,
final boolean expectRequired) throws Exception {
void testAuthenticate(
@Nullable final String authorizationHeaderValue,
final Map<String, List<String>> upgradeRequestParameters,
final boolean expectAccount,
final boolean expectCredentialsPresented) throws Exception {
when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters);
if (authorizationHeaderValue != null) {
when(upgradeRequest.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn(authorizationHeaderValue);
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator);
@ -68,20 +79,34 @@ class WebSocketAccountAuthenticatorTest {
final WebSocketAuthenticator.AuthenticationResult<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(
upgradeRequest);
if (expectAccount) {
assertTrue(result.getUser().isPresent());
} else {
assertTrue(result.getUser().isEmpty());
}
assertEquals(expectRequired, result.isRequired());
assertEquals(expectAccount, result.getUser().isPresent());
assertEquals(expectCredentialsPresented, result.credentialsPresented());
}
private static Stream<Arguments> testAuthenticate() {
final Map<String, List<String>> paramsMapWithValidAuth =
Map.of("login", List.of(VALID_USER), "password", List.of(VALID_PASSWORD));
final Map<String, List<String>> paramsMapWithInvalidAuth =
Map.of("login", List.of(INVALID_USER), "password", List.of(INVALID_PASSWORD));
final String headerWithValidAuth =
HeaderUtils.basicAuthHeader(VALID_USER, VALID_PASSWORD);
final String headerWithInvalidAuth =
HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD);
return Stream.of(
Arguments.of(Map.of("login", List.of(VALID_USER), "password", List.of(VALID_PASSWORD)), true, true),
Arguments.of(Map.of("login", List.of(INVALID_USER), "password", List.of(INVALID_PASSWORD)), false, true),
Arguments.of(Map.of(), false, false)
// if `Authorization` header is present, outcome should not depend on the value of query parameters
Arguments.of(headerWithValidAuth, Map.of(), true, true),
Arguments.of(headerWithInvalidAuth, Map.of(), false, true),
Arguments.of("invalid header value", Map.of(), false, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true),
Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithValidAuth, false, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true),
Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true),
// if `Authorization` header is not set, outcome should match the query params based auth
Arguments.of(null, paramsMapWithValidAuth, true, true),
Arguments.of(null, paramsMapWithInvalidAuth, false, true),
Arguments.of(null, Map.of(), false, false)
);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -142,7 +142,7 @@ class WebSocketConnectionTest {
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.getUser().isPresent());
assertFalse(account.isRequired());
assertFalse(account.credentialsPresented());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket;
@ -57,7 +57,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends WebSo
if (authenticator.isPresent()) {
AuthenticationResult<T> authenticationResult = authenticator.get().authenticate(request);
if (authenticationResult.getUser().isEmpty() && authenticationResult.isRequired()) {
if (authenticationResult.getUser().isEmpty() && authenticationResult.credentialsPresented()) {
response.sendForbidden("Unauthorized");
return null;
} else {

View File

@ -1,33 +1,32 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
public interface WebSocketAuthenticator<T extends Principal> {
AuthenticationResult<T> authenticate(UpgradeRequest request) throws AuthenticationException;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AuthenticationResult<T> {
class AuthenticationResult<T> {
private final Optional<T> user;
private final boolean required;
private final boolean credentialsPresented;
public AuthenticationResult(Optional<T> user, boolean required) {
this.user = user;
this.required = required;
public AuthenticationResult(final Optional<T> user, final boolean credentialsPresented) {
this.user = user;
this.credentialsPresented = credentialsPresented;
}
public Optional<T> getUser() {
return user;
}
public boolean isRequired() {
return required;
public boolean credentialsPresented() {
return credentialsPresented;
}
}
}