diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java index 82e58fda..1d3d45ad 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java @@ -12,6 +12,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import com.stripe.exception.StripeException; import com.stripe.model.Charge; import com.stripe.model.Charge.Outcome; import com.stripe.model.Invoice; @@ -36,6 +37,7 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.crypto.Mac; @@ -294,6 +296,7 @@ public class SubscriptionController { public enum Type { UNSUPPORTED_LEVEL, UNSUPPORTED_CURRENCY, + PAYMENT_REQUIRES_ACTION, } private final Type type; @@ -374,6 +377,22 @@ public class SubscriptionController { // retries this request stripeManager.createSubscription(processorCustomer.customerId(), priceConfiguration.getId(), level, lastSubscriptionCreatedAt) + .exceptionally(e -> { + if (e.getCause() instanceof StripeException stripeException + && stripeException.getCode().equals("subscription_payment_intent_requires_action")) { + throw new BadRequestException(Response.status(Status.BAD_REQUEST) + .entity(new SetSubscriptionLevelErrorResponse(List.of( + new SetSubscriptionLevelErrorResponse.Error( + SetSubscriptionLevelErrorResponse.Error.Type.PAYMENT_REQUIRES_ACTION, null + ) + ))).build()); + } + if (e instanceof RuntimeException re) { + throw re; + } + + throw new CompletionException(e); + }) .thenCompose(subscription -> subscriptionManager.subscriptionCreated( requestData.subscriberUser, subscription.getId(), requestData.now, level) .thenApply(unused -> subscription))) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/StripeManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/StripeManager.java index 5364f931..4e8b5153 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/StripeManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/subscriptions/StripeManager.java @@ -211,6 +211,7 @@ public class StripeManager implements SubscriptionProcessorManager { SubscriptionCreateParams params = SubscriptionCreateParams.builder() .setCustomer(customerId) .setOffSession(true) + .setPaymentBehavior(SubscriptionCreateParams.PaymentBehavior.ERROR_IF_INCOMPLETE) .addItem(SubscriptionCreateParams.Item.builder() .setPrice(priceId) .build()) @@ -250,6 +251,7 @@ public class StripeManager implements SubscriptionProcessorManager { .setProrationBehavior(ProrationBehavior.NONE) .setBillingCycleAnchor(BillingCycleAnchor.NOW) .setOffSession(true) + .setPaymentBehavior(SubscriptionUpdateParams.PaymentBehavior.ERROR_IF_INCOMPLETE) .addAllItem(items) .build(); try { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SubscriptionControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SubscriptionControllerTest.java index fa6c383a..de7922d2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SubscriptionControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SubscriptionControllerTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; @@ -14,6 +15,8 @@ import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.util.AttributeValues.b; import static org.whispersystems.textsecuregcm.util.AttributeValues.n; +import com.stripe.exception.ApiException; +import com.stripe.model.Subscription; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; @@ -26,13 +29,16 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import javax.ws.rs.client.Entity; import javax.ws.rs.core.Response; import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; @@ -111,6 +117,115 @@ class SubscriptionControllerTest { assertThat(response.getStatus()).isEqualTo(422); } + @Nested + class SetSubscriptionLevel { + + private final long levelId = 5L; + private final String currency = "eur"; + + private String subscriberId; + + @BeforeEach + void setUp() { + when(CLOCK.instant()).thenReturn(Instant.now()); + + final byte[] subscriberUserAndKey = new byte[32]; + Arrays.fill(subscriberUserAndKey, (byte) 1); + subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); + + final ProcessorCustomer processorCustomer = new ProcessorCustomer("testCustomerId", SubscriptionProcessor.STRIPE); + + final Map dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), + SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), + SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()), + SubscriptionManager.KEY_PROCESSOR_ID_CUSTOMER_ID, b(processorCustomer.toDynamoBytes()) + ); + final SubscriptionManager.Record record = SubscriptionManager.Record.from( + Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); + when(SUBSCRIPTION_MANAGER.get(eq(Arrays.copyOfRange(subscriberUserAndKey, 0, 16)), any())) + .thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); + + final SubscriptionLevelConfiguration levelConfig = mock(SubscriptionLevelConfiguration.class); + when(SUBSCRIPTION_CONFIG.getLevels()) + .thenReturn(Map.of(levelId, levelConfig)); + + final SubscriptionPriceConfiguration priceConfig = new SubscriptionPriceConfiguration("testPriceId", + BigDecimal.TEN); + when(levelConfig.getPrices()) + .thenReturn(Map.of(currency, priceConfig)); + + when(SUBSCRIPTION_MANAGER.subscriptionCreated(any(), any(), any(), anyLong())) + .thenReturn(CompletableFuture.completedFuture(null)); + } + + @Test + void success() { + when(STRIPE_MANAGER.createSubscription(any(), any(), anyLong(), anyLong())) + .thenReturn(CompletableFuture.completedFuture(mock(Subscription.class))); + + final String level = String.valueOf(levelId); + final String idempotencyKey = UUID.randomUUID().toString(); + final Response response = RESOURCE_EXTENSION.target( + String.format("/v1/subscription/%s/level/%s/%s/%s", subscriberId, level, currency, idempotencyKey)) + .request() + .put(Entity.json("")); + + assertThat(response.getStatus()).isEqualTo(200); + } + + @Test + void missingCustomerId() { + final byte[] subscriberUserAndKey = new byte[32]; + Arrays.fill(subscriberUserAndKey, (byte) 1); + subscriberId = Base64.getEncoder().encodeToString(subscriberUserAndKey); + + final Map dynamoItem = Map.of(SubscriptionManager.KEY_PASSWORD, b(new byte[16]), + SubscriptionManager.KEY_CREATED_AT, n(Instant.now().getEpochSecond()), + SubscriptionManager.KEY_ACCESSED_AT, n(Instant.now().getEpochSecond()) + // missing processor:customer field + ); + final SubscriptionManager.Record record = SubscriptionManager.Record.from( + Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); + when(SUBSCRIPTION_MANAGER.get(eq(Arrays.copyOfRange(subscriberUserAndKey, 0, 16)), any())) + .thenReturn(CompletableFuture.completedFuture(SubscriptionManager.GetResult.found(record))); + + final String level = String.valueOf(levelId); + final String idempotencyKey = UUID.randomUUID().toString(); + final Response response = RESOURCE_EXTENSION.target( + String.format("/v1/subscription/%s/level/%s/%s/%s", subscriberId, level, currency, idempotencyKey)) + .request() + .put(Entity.json("")); + + assertThat(response.getStatus()).isEqualTo(409); + } + + @Test + void stripePaymentIntentRequiresAction() { + final ApiException stripeException = new ApiException("Payment intent requires action", + UUID.randomUUID().toString(), "subscription_payment_intent_requires_action", 400, new Exception()); + when(STRIPE_MANAGER.createSubscription(any(), any(), anyLong(), anyLong())) + .thenReturn(CompletableFuture.failedFuture(new CompletionException(stripeException))); + + final String level = String.valueOf(levelId); + final String idempotencyKey = UUID.randomUUID().toString(); + final Response response = RESOURCE_EXTENSION.target( + String.format("/v1/subscription/%s/level/%s/%s/%s", subscriberId, level, currency, idempotencyKey)) + .request() + .put(Entity.json("")); + + assertThat(response.getStatus()).isEqualTo(400); + + assertThat(response.readEntity(SubscriptionController.SetSubscriptionLevelErrorResponse.class)) + .satisfies(errorResponse -> { + assertThat(errorResponse.getErrors()) + .anySatisfy(error -> { + assertThat(error.getType()).isEqualTo( + SubscriptionController.SetSubscriptionLevelErrorResponse.Error.Type.PAYMENT_REQUIRES_ACTION); + }); + }); + } + } + @Test void createSubscriber() { when(CLOCK.instant()).thenReturn(Instant.now()); @@ -205,8 +320,7 @@ class SubscriptionControllerTest { final SubscriptionManager.Record record = SubscriptionManager.Record.from( Arrays.copyOfRange(subscriberUserAndKey, 0, 16), dynamoItem); when(SUBSCRIPTION_MANAGER.create(any(), any(), any(Instant.class))) - .thenReturn(CompletableFuture.completedFuture( - record)); + .thenReturn(CompletableFuture.completedFuture(record)); final Response createSubscriberResponse = RESOURCE_EXTENSION .target(String.format("/v1/subscription/%s", subscriberId))