Skip to content

Commit

Permalink
chore(spanner): add multiplexed session support for batch write (#3470)
Browse files Browse the repository at this point in the history
* chore(spanner): add multiplexed session support for batch write

* chore(spanner): lint fix
  • Loading branch information
harshachinta authored Jan 29, 2025
1 parent 1f143a4 commit 9a5d86b
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@

package com.google.cloud.spanner;

import com.google.api.gax.rpc.ServerStream;
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.Options.TransactionOption;
import com.google.spanner.v1.BatchWriteResponse;

/**
* Base class for the Multiplexed Session {@link DatabaseClient} implementation. Throws {@link
Expand All @@ -43,11 +40,4 @@ public String getDatabaseRole() {
public Timestamp writeAtLeastOnce(Iterable<Mutation> mutations) throws SpannerException {
return writeAtLeastOnceWithOptions(mutations).getCommitTimestamp();
}

@Override
public ServerStream<BatchWriteResponse> batchWriteAtLeastOnce(
Iterable<MutationGroup> mutationGroups, TransactionOption... options)
throws SpannerException {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ public ServerStream<BatchWriteResponse> batchWriteAtLeastOnce(
throws SpannerException {
ISpan span = tracer.spanBuilder(READ_WRITE_TRANSACTION, commonAttributes, options);
try (IScope s = tracer.withSpan(span)) {
if (canUseMultiplexedSessionsForRW() && getMultiplexedSessionDatabaseClient() != null) {
return getMultiplexedSessionDatabaseClient().batchWriteAtLeastOnce(mutationGroups, options);
}
return runWithSessionRetry(session -> session.batchWriteAtLeastOnce(mutationGroups, options));
} catch (RuntimeException e) {
span.setStatus(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutures;
import com.google.api.gax.rpc.ServerStream;
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.DelayedReadContext.DelayedReadOnlyTransaction;
import com.google.cloud.spanner.MultiplexedSessionDatabaseClient.MultiplexedSessionTransaction;
import com.google.cloud.spanner.Options.TransactionOption;
import com.google.cloud.spanner.Options.UpdateOption;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.spanner.v1.BatchWriteResponse;
import java.util.concurrent.ExecutionException;

/**
Expand Down Expand Up @@ -164,6 +166,22 @@ public CommitResponse writeWithOptions(Iterable<Mutation> mutations, Transaction
}
}

/**
* This is a blocking method, as the interface that it implements is also defined as a blocking
* method.
*/
@Override
public ServerStream<BatchWriteResponse> batchWriteAtLeastOnce(
Iterable<MutationGroup> mutationGroups, TransactionOption... options)
throws SpannerException {
SessionReference sessionReference = getSessionReference();
try (MultiplexedSessionTransaction transaction =
new MultiplexedSessionTransaction(
client, span, sessionReference, NO_CHANNEL_HINT, /* singleUse = */ true)) {
return transaction.batchWriteAtLeastOnce(mutationGroups, options);
}
}

@Override
public TransactionRunner readWriteTransaction(TransactionOption... options) {
return new DelayedTransactionRunner(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutures;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ServerStream;
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.Options.TransactionOption;
import com.google.cloud.spanner.Options.UpdateOption;
Expand All @@ -30,6 +31,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.spanner.v1.BatchWriteResponse;
import com.google.spanner.v1.BeginTransactionRequest;
import com.google.spanner.v1.RequestOptions;
import com.google.spanner.v1.Transaction;
Expand Down Expand Up @@ -505,6 +507,14 @@ public CommitResponse writeAtLeastOnceWithOptions(
.writeAtLeastOnceWithOptions(mutations, options);
}

@Override
public ServerStream<BatchWriteResponse> batchWriteAtLeastOnce(
Iterable<MutationGroup> mutationGroups, TransactionOption... options)
throws SpannerException {
return createMultiplexedSessionTransaction(/* singleUse = */ true)
.batchWriteAtLeastOnce(mutationGroups, options);
}

@Override
public ReadContext singleUse() {
return createMultiplexedSessionTransaction(/* singleUse = */ true).singleUse();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ public ServerStream<BatchWriteResponse> batchWriteAtLeastOnce(
throw SpannerExceptionFactory.newSpannerException(e);
} finally {
span.end();
onTransactionDone();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutures;
import com.google.api.gax.rpc.ServerStream;
import com.google.cloud.NoCredentials;
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.AsyncTransactionManager.AsyncTransactionStep;
Expand All @@ -45,6 +46,8 @@
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import com.google.spanner.v1.BatchWriteRequest;
import com.google.spanner.v1.BatchWriteResponse;
import com.google.spanner.v1.BeginTransactionRequest;
import com.google.spanner.v1.CommitRequest;
import com.google.spanner.v1.ExecuteSqlRequest;
Expand Down Expand Up @@ -1635,6 +1638,44 @@ public void testReadWriteTransactionWithCommitRetryProtocolExtensionSet() {
assertEquals(1L, client.multiplexedSessionDatabaseClient.getNumSessionsReleased().get());
}

@Test
public void testBatchWriteAtLeastOnce() {
DatabaseClientImpl client =
(DatabaseClientImpl) spanner.getDatabaseClient(DatabaseId.of("p", "i", "d"));

Iterable<MutationGroup> MUTATION_GROUPS =
ImmutableList.of(
MutationGroup.of(
Mutation.newInsertBuilder("FOO1").set("ID").to(1L).set("NAME").to("Bar1").build(),
Mutation.newInsertBuilder("FOO2").set("ID").to(2L).set("NAME").to("Bar2").build()),
MutationGroup.of(
Mutation.newInsertBuilder("FOO3").set("ID").to(3L).set("NAME").to("Bar3").build(),
Mutation.newInsertBuilder("FOO4").set("ID").to(4L).set("NAME").to("Bar4").build()));

ServerStream<BatchWriteResponse> responseStream = client.batchWriteAtLeastOnce(MUTATION_GROUPS);
int idx = 0;
for (BatchWriteResponse response : responseStream) {
assertEquals(
response.getStatus(),
com.google.rpc.Status.newBuilder().setCode(com.google.rpc.Code.OK_VALUE).build());
assertEquals(response.getIndexesList(), ImmutableList.of(idx, idx + 1));
idx += 2;
}

assertNotNull(responseStream);
List<BatchWriteRequest> requests = mockSpanner.getRequestsOfType(BatchWriteRequest.class);
assertEquals(requests.size(), 1);
BatchWriteRequest request = requests.get(0);
assertTrue(mockSpanner.getSession(request.getSession()).getMultiplexed());
assertEquals(request.getMutationGroupsCount(), 2);
assertEquals(request.getRequestOptions().getPriority(), Priority.PRIORITY_UNSPECIFIED);
assertFalse(request.getExcludeTxnFromChangeStreams());

assertNotNull(client.multiplexedSessionDatabaseClient);
assertEquals(1L, client.multiplexedSessionDatabaseClient.getNumSessionsAcquired().get());
assertEquals(1L, client.multiplexedSessionDatabaseClient.getNumSessionsReleased().get());
}

private void waitForSessionToBeReplaced(DatabaseClientImpl client) {
assertNotNull(client.multiplexedSessionDatabaseClient);
SessionReference sessionReference =
Expand Down

0 comments on commit 9a5d86b

Please sign in to comment.