From 60bb0604217fbd44dbc6b115b59fc26a297fcaba Mon Sep 17 00:00:00 2001
From: "Doroszlai, Attila" <6454655+adoroszlai@users.noreply.github.com>
Date: Thu, 23 Nov 2023 21:38:01 +0100
Subject: [PATCH] HDDS-9734. ChunkInputStream should use new token after
 pipeline refresh (#5664)

---
 .../hdds/scm/storage/BlockInputStream.java   | 112 +++++++++------
 .../hdds/scm/storage/ChunkInputStream.java   | 131 +++++++++---------
 .../scm/storage/DummyBlockInputStream.java   |   2 +-
 .../DummyBlockInputStreamWithRetry.java      |   4 +-
 .../scm/storage/DummyChunkInputStream.java   |   2 +-
 .../scm/storage/TestBlockInputStream.java    |   6 +-
 .../scm/storage/TestChunkInputStream.java    |  62 ++++++---
 7 files changed, 183 insertions(+), 136 deletions(-)

diff --git a/hadoop-hdds/client/src/main/java/org/apache/hadoop/hdds/scm/storage/BlockInputStream.java b/hadoop-hdds/client/src/main/java/org/apache/hadoop/hdds/scm/storage/BlockInputStream.java
index c10e271f2e5..385ea6d0c3e 100644
--- a/hadoop-hdds/client/src/main/java/org/apache/hadoop/hdds/scm/storage/BlockInputStream.java
+++ b/hadoop-hdds/client/src/main/java/org/apache/hadoop/hdds/scm/storage/BlockInputStream.java
@@ -25,11 +25,11 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 
 import com.google.common.base.Preconditions;
 import org.apache.hadoop.hdds.client.BlockID;
-import org.apache.hadoop.hdds.client.ReplicationConfig;
 import org.apache.hadoop.hdds.client.StandaloneReplicationConfig;
 import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ContainerCommandResponseProto;
 import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ChunkInfo;
@@ -52,6 +52,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import static org.apache.hadoop.hdds.client.ReplicationConfig.getLegacyFactor;
+
 /**
  * An {@link InputStream} called from KeyInputStream to read a block from the
  * container.
@@ -65,8 +67,10 @@ public class BlockInputStream extends BlockExtendedInputStream {
   private final BlockID blockID;
   private final long length;
-  private Pipeline pipeline;
-  private Token<OzoneBlockTokenIdentifier> token;
+  private final AtomicReference<Pipeline> pipelineRef =
+      new AtomicReference<>();
+  private final AtomicReference<Token<OzoneBlockTokenIdentifier>> tokenRef =
+      new AtomicReference<>();
   private final boolean verifyChecksum;
   private XceiverClientFactory xceiverClientFactory;
   private XceiverClientSpi xceiverClient;
@@ -113,8 +117,8 @@ public BlockInputStream(BlockID blockId, long blockLen, Pipeline pipeline,
       Function<BlockID, BlockLocationInfo> refreshFunction) {
     this.blockID = blockId;
     this.length = blockLen;
-    this.pipeline = pipeline;
-    this.token = token;
+    setPipeline(pipeline);
+    tokenRef.set(token);
     this.verifyChecksum = verifyChecksum;
     this.xceiverClientFactory = xceiverClientFactory;
     this.refreshFunction = refreshFunction;
@@ -143,7 +147,7 @@ public synchronized void initialize() throws IOException {
     IOException catchEx = null;
     do {
       try {
-        chunks = getChunkInfos();
+        chunks = getChunkInfoList();
         break;
         // If we get a StorageContainerException or an IOException due to
         // datanodes are not reachable, refresh to get the latest pipeline
@@ -203,7 +207,7 @@ private boolean isConnectivityIssue(IOException ex) {
   private void refreshBlockInfo(IOException cause) throws IOException {
     LOG.info("Unable to read information for block {} from pipeline {}: {}",
-        blockID, pipeline.getId(), cause.getMessage());
+        blockID, pipelineRef.get().getId(), cause.getMessage());
     if (refreshFunction != null) {
       LOG.debug("Re-fetching pipeline and block token for block {}", blockID);
       BlockLocationInfo blockLocationInfo = refreshFunction.apply(blockID);
@@ -212,8 +216,8 @@ private void refreshBlockInfo(IOException cause) throws IOException {
       } else {
         LOG.debug("New pipeline for block {}: {}", blockID,
             blockLocationInfo.getPipeline());
-        this.pipeline = blockLocationInfo.getPipeline();
-        this.token = blockLocationInfo.getToken();
+        setPipeline(blockLocationInfo.getPipeline());
+        tokenRef.set(blockLocationInfo.getToken());
       }
     } else {
       throw cause;
@@ -224,46 +228,55 @@ private void refreshBlockInfo(IOException cause) throws IOException {
    * Send RPC call to get the block info from the container.
    * @return List of chunks in this block.
    */
-  protected List<ChunkInfo> getChunkInfos() throws IOException {
-    // irrespective of the container state, we will always read via Standalone
-    // protocol.
-    if (pipeline.getType() != HddsProtos.ReplicationType.STAND_ALONE && pipeline
-        .getType() != HddsProtos.ReplicationType.EC) {
-      pipeline = Pipeline.newBuilder(pipeline)
-          .setReplicationConfig(StandaloneReplicationConfig.getInstance(
-              ReplicationConfig
-                  .getLegacyFactor(pipeline.getReplicationConfig())))
-          .build();
-    }
+  protected List<ChunkInfo> getChunkInfoList() throws IOException {
+    acquireClient();
     try {
-      acquireClient();
-    } catch (IOException ioe) {
-      LOG.warn("Failed to acquire client for pipeline {}, block {}",
-          pipeline, blockID);
-      throw ioe;
+      return getChunkInfoListUsingClient();
+    } finally {
+      releaseClient();
     }
-    try {
-      if (LOG.isDebugEnabled()) {
-        LOG.debug("Initializing BlockInputStream for get key to access {}",
-            blockID.getContainerID());
-      }
+  }
 
-      DatanodeBlockID.Builder blkIDBuilder =
-          DatanodeBlockID.newBuilder().setContainerID(blockID.getContainerID())
-              .setLocalID(blockID.getLocalID())
-              .setBlockCommitSequenceId(blockID.getBlockCommitSequenceId());
+  @VisibleForTesting
+  protected List<ChunkInfo> getChunkInfoListUsingClient() throws IOException {
+    final Pipeline pipeline = xceiverClient.getPipeline();
 
-      int replicaIndex = pipeline.getReplicaIndex(pipeline.getClosestNode());
-      if (replicaIndex > 0) {
-        blkIDBuilder.setReplicaIndex(replicaIndex);
-      }
-      GetBlockResponseProto response = ContainerProtocolCalls
-          .getBlock(xceiverClient, VALIDATORS, blkIDBuilder.build(), token);
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Initializing BlockInputStream for get key to access {}",
+          blockID.getContainerID());
+    }
 
-      return response.getBlockData().getChunksList();
-    } finally {
-      releaseClient();
+    DatanodeBlockID.Builder blkIDBuilder =
+        DatanodeBlockID.newBuilder().setContainerID(blockID.getContainerID())
+            .setLocalID(blockID.getLocalID())
+            .setBlockCommitSequenceId(blockID.getBlockCommitSequenceId());
+
+    int replicaIndex = pipeline.getReplicaIndex(pipeline.getClosestNode());
+    if (replicaIndex > 0) {
+      blkIDBuilder.setReplicaIndex(replicaIndex);
     }
+
+    GetBlockResponseProto response = ContainerProtocolCalls.getBlock(
+        xceiverClient, VALIDATORS, blkIDBuilder.build(), tokenRef.get());
+
+    return response.getBlockData().getChunksList();
+  }
+
+  private void setPipeline(Pipeline pipeline) {
+    if (pipeline == null) {
+      return;
+    }
+
+    // irrespective of the container state, we will always read via Standalone
+    // protocol.
+    boolean okForRead =
+        pipeline.getType() == HddsProtos.ReplicationType.STAND_ALONE
+        || pipeline.getType() == HddsProtos.ReplicationType.EC;
+    Pipeline readPipeline = okForRead ? pipeline : Pipeline.newBuilder(pipeline)
+        .setReplicationConfig(StandaloneReplicationConfig.getInstance(
+            getLegacyFactor(pipeline.getReplicationConfig())))
+        .build();
+    pipelineRef.set(readPipeline);
   }
 
   private static final List<Validator> VALIDATORS
@@ -286,9 +299,16 @@ private static void validate(ContainerCommandResponseProto response)
     }
   }
 
-  protected void acquireClient() throws IOException {
+  private void acquireClient() throws IOException {
     if (xceiverClientFactory != null && xceiverClient == null) {
-      xceiverClient = xceiverClientFactory.acquireClientForReadData(pipeline);
+      final Pipeline pipeline = pipelineRef.get();
+      try {
+        xceiverClient = xceiverClientFactory.acquireClientForReadData(pipeline);
+      } catch (IOException ioe) {
+        LOG.warn("Failed to acquire client for pipeline {}, block {}",
+            pipeline, blockID);
+        throw ioe;
+      }
     }
   }
 
@@ -303,7 +323,7 @@ protected synchronized void addStream(ChunkInfo chunkInfo) {
 
   protected ChunkInputStream createChunkInputStream(ChunkInfo chunkInfo) {
     return new ChunkInputStream(chunkInfo, blockID,
-        xceiverClientFactory, () -> pipeline, verifyChecksum, token);
+        xceiverClientFactory, pipelineRef::get, verifyChecksum, tokenRef::get);
   }
 
   @Override
diff --git a/hadoop-hdds/client/src/main/java/org/apache/hadoop/hdds/scm/storage/ChunkInputStream.java b/hadoop-hdds/client/src/main/java/org/apache/hadoop/hdds/scm/storage/ChunkInputStream.java
index fb9e3455862..b30f555795b 100644
--- a/hadoop-hdds/client/src/main/java/org/apache/hadoop/hdds/scm/storage/ChunkInputStream.java
+++ b/hadoop-hdds/client/src/main/java/org/apache/hadoop/hdds/scm/storage/ChunkInputStream.java
@@ -27,6 +27,8 @@
 import org.apache.hadoop.fs.Seekable;
 import org.apache.hadoop.hdds.client.BlockID;
 import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ChunkInfo;
+import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ContainerCommandRequestProto;
+import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ContainerCommandResponseProto;
 import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ReadChunkResponseProto;
 import org.apache.hadoop.hdds.scm.XceiverClientFactory;
 import org.apache.hadoop.hdds.scm.XceiverClientSpi;
@@ -37,7 +39,6 @@
 import org.apache.hadoop.ozone.common.OzoneChecksumException;
 import org.apache.hadoop.ozone.common.utils.BufferUtils;
 import org.apache.hadoop.security.token.Token;
-import org.apache.hadoop.security.token.TokenIdentifier;
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
 
 import java.io.EOFException;
@@ -48,9 +49,6 @@
 import java.util.List;
 import java.util.function.Supplier;
 
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
 /**
  * An {@link InputStream} called from BlockInputStream to read a chunk from the
  * container. Each chunk may contain multiple underlying {@link ByteBuffer}
@@ -59,16 +57,13 @@ public class ChunkInputStream extends InputStream
     implements Seekable, CanUnbuffer, ByteBufferReadable {
 
-  private static final Logger LOG =
-      LoggerFactory.getLogger(ChunkInputStream.class);
-
-  private ChunkInfo chunkInfo;
+  private final ChunkInfo chunkInfo;
   private final long length;
   private final BlockID blockID;
   private final XceiverClientFactory xceiverClientFactory;
   private XceiverClientSpi xceiverClient;
   private final Supplier<Pipeline> pipelineSupplier;
-  private boolean verifyChecksum;
+  private final boolean verifyChecksum;
   private boolean allocated = false;
   // Buffers to store the chunk data read from the DN container
   private ByteBuffer[] buffers;
@@ -100,21 +95,24 @@ public class ChunkInputStream extends InputStream
   // retry. Once the chunk is read, this variable is reset.
   private long chunkPosition = -1;
 
-  private final Token<OzoneBlockTokenIdentifier> token;
+  private final Supplier<Token<OzoneBlockTokenIdentifier>> tokenSupplier;
 
   private static final int EOF = -1;
+  private final List<Validator> validators;
 
   ChunkInputStream(ChunkInfo chunkInfo, BlockID blockId,
       XceiverClientFactory xceiverClientFactory,
       Supplier<Pipeline> pipelineSupplier,
-      boolean verifyChecksum, Token<OzoneBlockTokenIdentifier> token) {
+      boolean verifyChecksum,
+      Supplier<Token<OzoneBlockTokenIdentifier>> tokenSupplier) {
     this.chunkInfo = chunkInfo;
     this.length = chunkInfo.getLen();
     this.blockID = blockId;
     this.xceiverClientFactory = xceiverClientFactory;
     this.pipelineSupplier = pipelineSupplier;
     this.verifyChecksum = verifyChecksum;
-    this.token = token;
+    this.tokenSupplier = tokenSupplier;
+    validators = ContainerProtocolCalls.toValidatorList(this::validateChunk);
   }
 
   public synchronized long getRemaining() {
@@ -422,13 +420,10 @@ private void readChunkDataIntoBuffers(ChunkInfo readChunkInfo)
   @VisibleForTesting
   protected ByteBuffer[] readChunk(ChunkInfo readChunkInfo)
       throws IOException {
-    ReadChunkResponseProto readChunkResponse;
-    List<Validator> validators =
-        ContainerProtocolCalls.toValidatorList(validator);
-
-    readChunkResponse = ContainerProtocolCalls.readChunk(xceiverClient,
-        readChunkInfo, blockID, validators, token);
+    ReadChunkResponseProto readChunkResponse =
+        ContainerProtocolCalls.readChunk(xceiverClient,
+            readChunkInfo, blockID, validators, tokenSupplier.get());
 
     if (readChunkResponse.hasData()) {
       return readChunkResponse.getData().asReadOnlyByteBufferList()
@@ -443,55 +438,57 @@ protected ByteBuffer[] readChunk(ChunkInfo readChunkInfo)
     }
   }
 
-  private final Validator validator =
-      (request, response) -> {
-        final ChunkInfo reqChunkInfo =
-            request.getReadChunk().getChunkData();
-
-        ReadChunkResponseProto readChunkResponse = response.getReadChunk();
-        List<ByteString> byteStrings;
-        boolean isV0 = false;
-
-        if (readChunkResponse.hasData()) {
-          ByteString byteString = readChunkResponse.getData();
-          if (byteString.size() != reqChunkInfo.getLen()) {
-            // Bytes read from chunk should be equal to chunk size.
-            throw new OzoneChecksumException(String.format(
-                "Inconsistent read for chunk=%s len=%d bytesRead=%d",
-                reqChunkInfo.getChunkName(), reqChunkInfo.getLen(),
-                byteString.size()));
-          }
-          byteStrings = new ArrayList<>();
-          byteStrings.add(byteString);
-          isV0 = true;
-        } else {
-          byteStrings = readChunkResponse.getDataBuffers().getBuffersList();
-          long buffersLen = BufferUtils.getBuffersLen(byteStrings);
-          if (buffersLen != reqChunkInfo.getLen()) {
-            // Bytes read from chunk should be equal to chunk size.
-            throw new OzoneChecksumException(String.format(
-                "Inconsistent read for chunk=%s len=%d bytesRead=%d",
-                reqChunkInfo.getChunkName(), reqChunkInfo.getLen(),
-                buffersLen));
-          }
-        }
-
-        if (verifyChecksum) {
-          ChecksumData checksumData = ChecksumData.getFromProtoBuf(
-              chunkInfo.getChecksumData());
-
-          // ChecksumData stores checksum for each 'numBytesPerChecksum'
-          // number of bytes in a list. Compute the index of the first
-          // checksum to match with the read data
-
-          long relativeOffset = reqChunkInfo.getOffset() -
-              chunkInfo.getOffset();
-          int bytesPerChecksum = checksumData.getBytesPerChecksum();
-          int startIndex = (int) (relativeOffset / bytesPerChecksum);
-          Checksum.verifyChecksum(byteStrings, checksumData, startIndex,
-              isV0);
-        }
-      };
+  private void validateChunk(
+      ContainerCommandRequestProto request,
+      ContainerCommandResponseProto response
+  ) throws OzoneChecksumException {
+    final ChunkInfo reqChunkInfo =
+        request.getReadChunk().getChunkData();
+
+    ReadChunkResponseProto readChunkResponse = response.getReadChunk();
+    List<ByteString> byteStrings;
+    boolean isV0 = false;
+
+    if (readChunkResponse.hasData()) {
+      ByteString byteString = readChunkResponse.getData();
+      if (byteString.size() != reqChunkInfo.getLen()) {
+        // Bytes read from chunk should be equal to chunk size.
+        throw new OzoneChecksumException(String.format(
+            "Inconsistent read for chunk=%s len=%d bytesRead=%d",
+            reqChunkInfo.getChunkName(), reqChunkInfo.getLen(),
+            byteString.size()));
+      }
+      byteStrings = new ArrayList<>();
+      byteStrings.add(byteString);
+      isV0 = true;
+    } else {
+      byteStrings = readChunkResponse.getDataBuffers().getBuffersList();
+      long buffersLen = BufferUtils.getBuffersLen(byteStrings);
+      if (buffersLen != reqChunkInfo.getLen()) {
+        // Bytes read from chunk should be equal to chunk size.
+        throw new OzoneChecksumException(String.format(
+            "Inconsistent read for chunk=%s len=%d bytesRead=%d",
+            reqChunkInfo.getChunkName(), reqChunkInfo.getLen(),
+            buffersLen));
+      }
+    }
+
+    if (verifyChecksum) {
+      ChecksumData checksumData = ChecksumData.getFromProtoBuf(
+          chunkInfo.getChecksumData());
+
+      // ChecksumData stores checksum for each 'numBytesPerChecksum'
+      // number of bytes in a list. Compute the index of the first
+      // checksum to match with the read data
+
+      long relativeOffset = reqChunkInfo.getOffset() -
+          chunkInfo.getOffset();
+      int bytesPerChecksum = checksumData.getBytesPerChecksum();
+      int startIndex = (int) (relativeOffset / bytesPerChecksum);
+      Checksum.verifyChecksum(byteStrings, checksumData, startIndex,
+          isV0);
+    }
+  }
 
   /**
    * Return the offset and length of bytes that need to be read from the
diff --git a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyBlockInputStream.java b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyBlockInputStream.java
index be72dd07016..3e7779f0d10 100644
--- a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyBlockInputStream.java
+++ b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyBlockInputStream.java
@@ -57,7 +57,7 @@ class DummyBlockInputStream extends BlockInputStream {
   }
 
   @Override
-  protected List<ChunkInfo> getChunkInfos() throws IOException {
+  protected List<ChunkInfo> getChunkInfoList() throws IOException {
     return chunks;
   }
 
diff --git a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyBlockInputStreamWithRetry.java b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyBlockInputStreamWithRetry.java
index b39ed61d703..24a35745144 100644
--- a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyBlockInputStreamWithRetry.java
+++ b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyBlockInputStreamWithRetry.java
@@ -73,7 +73,7 @@ final class DummyBlockInputStreamWithRetry
   }
 
   @Override
-  protected List<ChunkInfo> getChunkInfos() throws IOException {
+  protected List<ChunkInfo> getChunkInfoList() throws IOException {
     if (getChunkInfoCount == 0) {
       getChunkInfoCount++;
       if (ioException != null) {
@@ -82,7 +82,7 @@ protected List<ChunkInfo> getChunkInfos() throws IOException {
         throw new StorageContainerException("Exception encountered",
             CONTAINER_NOT_FOUND);
       } else {
-        return super.getChunkInfos();
+        return super.getChunkInfoList();
       }
     }
   }
diff --git a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyChunkInputStream.java b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyChunkInputStream.java
index 78d0c05bfe0..25675607870 100644
--- a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyChunkInputStream.java
+++ b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/DummyChunkInputStream.java
@@ -45,7 +45,7 @@ public DummyChunkInputStream(ChunkInfo chunkInfo,
       boolean verifyChecksum, byte[] data, Pipeline pipeline) {
     super(chunkInfo, blockId, xceiverClientFactory, () -> pipeline,
-        verifyChecksum, null);
+        verifyChecksum, () -> null);
     this.chunkData = data.clone();
   }
 
diff --git a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/TestBlockInputStream.java b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/TestBlockInputStream.java
index 6c518738cb7..2e95de1ecad 100644
--- a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/TestBlockInputStream.java
+++ b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/TestBlockInputStream.java
@@ -97,7 +97,8 @@ public void setup() throws Exception {
     checksum = new Checksum(ChecksumType.NONE, CHUNK_SIZE);
     createChunkList(5);
 
-    blockStream = new DummyBlockInputStream(blockID, blockSize, null, null,
+    Pipeline pipeline = MockPipeline.createSingleNodePipeline();
+    blockStream = new DummyBlockInputStream(blockID, blockSize, pipeline, null,
         false, null, refreshFunction, chunks, chunkDataMap);
   }
 
@@ -413,8 +414,7 @@ public void testRefreshOnReadFailureAfterUnbuffer(IOException ex)
     BlockInputStream subject = new BlockInputStream(blockID, blockSize,
         pipeline, null, false, clientFactory, refreshFunction) {
       @Override
-      protected List<ChunkInfo> getChunkInfos() throws IOException {
-        acquireClient();
+      protected List<ChunkInfo> getChunkInfoListUsingClient() {
         return chunks;
       }
 
diff --git a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/TestChunkInputStream.java b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/TestChunkInputStream.java
index 3fe861402d7..f45529412fe 100644
--- a/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/TestChunkInputStream.java
+++ b/hadoop-hdds/client/src/test/java/org/apache/hadoop/hdds/scm/storage/TestChunkInputStream.java
@@ -22,25 +22,34 @@
 import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 
+import org.apache.hadoop.hdds.client.BlockID;
 import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ChecksumType;
 import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ChunkInfo;
+import org.apache.hadoop.hdds.protocol.datanode.proto.ContainerProtos.ContainerCommandRequestProto;
+import org.apache.hadoop.hdds.scm.ByteStringConversion;
 import org.apache.hadoop.hdds.scm.XceiverClientFactory;
 import org.apache.hadoop.hdds.scm.XceiverClientSpi;
 import org.apache.hadoop.hdds.scm.pipeline.MockPipeline;
 import org.apache.hadoop.hdds.scm.pipeline.Pipeline;
 import org.apache.hadoop.ozone.common.Checksum;
+import org.apache.hadoop.ozone.common.ChunkBuffer;
+import org.apache.hadoop.security.token.Token;
 import org.apache.ozone.test.GenericTestUtils;
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.fail;
+import static org.apache.hadoop.hdds.scm.protocolPB.ContainerCommandResponseBuilders.getReadChunkResponse;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -54,8 +63,10 @@ public class TestChunkInputStream {
   private static final int BYTES_PER_CHECKSUM = 20;
   private static final String CHUNK_NAME = "dummyChunk";
   private static final Random RANDOM = new Random();
+  private static final AtomicLong CONTAINER_ID = new AtomicLong();
 
   private DummyChunkInputStream chunkStream;
+  private BlockID blockID;
   private ChunkInfo chunkInfo;
   private byte[] chunkData;
 
@@ -65,6 +76,8 @@ public void setup() throws Exception {
 
     chunkData = generateRandomData(CHUNK_SIZE);
 
+    blockID = new BlockID(CONTAINER_ID.incrementAndGet(), 0);
+
     chunkInfo = ChunkInfo.newBuilder()
         .setChunkName(CHUNK_NAME)
         .setOffset(0)
@@ -73,7 +86,7 @@ public void setup() throws Exception {
             chunkData, 0, CHUNK_SIZE).getProtoBufMessage())
         .build();
 
-    chunkStream = new DummyChunkInputStream(chunkInfo, null, null, true,
+    chunkStream = new DummyChunkInputStream(chunkInfo, blockID, null, true,
         chunkData, null);
   }
 
@@ -229,29 +242,46 @@ public void connectsToNewPipeline() throws Exception {
     // GIVEN
     Pipeline pipeline = MockPipeline.createSingleNodePipeline();
     Pipeline newPipeline = MockPipeline.createSingleNodePipeline();
-    XceiverClientFactory clientFactory = mock(XceiverClientFactory.class);
-    XceiverClientSpi client = mock(XceiverClientSpi.class);
-    when(clientFactory.acquireClientForReadData(pipeline))
-        .thenReturn(client);
+
+    Token<OzoneBlockTokenIdentifier> token = mock(Token.class);
+    when(token.encodeToUrlString())
+        .thenReturn("oldToken");
+    Token<OzoneBlockTokenIdentifier> newToken = mock(Token.class);
+    when(newToken.encodeToUrlString())
+        .thenReturn("newToken");
 
     AtomicReference<Pipeline> pipelineRef = new AtomicReference<>(pipeline);
+    AtomicReference<Token<OzoneBlockTokenIdentifier>> tokenRef =
+        new AtomicReference<>(token);
 
-    try (ChunkInputStream subject = new ChunkInputStream(chunkInfo, null,
-        clientFactory, pipelineRef::get, false, null) {
-      @Override
-      protected ByteBuffer[] readChunk(ChunkInfo readChunkInfo) {
-        return ByteString.copyFrom(chunkData).asReadOnlyByteBufferList()
-            .toArray(new ByteBuffer[0]);
-      }
-    }) {
+    XceiverClientFactory clientFactory = mock(XceiverClientFactory.class);
+    XceiverClientSpi client = mock(XceiverClientSpi.class);
+    when(clientFactory.acquireClientForReadData(any()))
+        .thenReturn(client);
+    ArgumentCaptor<ContainerCommandRequestProto> requestCaptor =
+        ArgumentCaptor.forClass(ContainerCommandRequestProto.class);
+    when(client.getPipeline())
+        .thenAnswer(invocation -> pipelineRef.get());
+    when(client.sendCommand(requestCaptor.capture(), any()))
+        .thenAnswer(invocation ->
+            getReadChunkResponse(
+                requestCaptor.getValue(),
+                ChunkBuffer.wrap(ByteBuffer.wrap(chunkData)),
+                ByteStringConversion::safeWrap));
+
+    try (ChunkInputStream subject = new ChunkInputStream(chunkInfo, blockID,
+        clientFactory, pipelineRef::get, false, tokenRef::get)) {
       // WHEN
       subject.unbuffer();
       pipelineRef.set(newPipeline);
-      int b = subject.read();
+      tokenRef.set(newToken);
+      byte[] buffer = new byte[CHUNK_SIZE];
+      int read = subject.read(buffer);
 
       // THEN
-      assertNotEquals(-1, b);
+      assertEquals(CHUNK_SIZE, read);
+      assertArrayEquals(chunkData, buffer);
       verify(clientFactory).acquireClientForReadData(newPipeline);
+      verify(newToken).encodeToUrlString();
     }
   }
 }
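
The essence of the fix above is that BlockInputStream now keeps the pipeline and block token in AtomicReference holders and hands ChunkInputStream supplier views of them (pipelineRef::get, tokenRef::get) instead of a token captured at construction time, so chunk reads issued after refreshBlockInfo() automatically use the re-fetched token. The following self-contained sketch illustrates that supplier pattern in isolation; ChunkReader, BlockReader and the String-typed pipeline/token are simplified stand-ins, not Ozone APIs.

import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

// Simplified stand-in for ChunkInputStream: it never caches the pipeline or
// token itself, it re-reads them through the suppliers on every request.
class ChunkReader {
    private final Supplier<String> pipelineSupplier;
    private final Supplier<String> tokenSupplier;

    ChunkReader(Supplier<String> pipelineSupplier, Supplier<String> tokenSupplier) {
        this.pipelineSupplier = pipelineSupplier;
        this.tokenSupplier = tokenSupplier;
    }

    String readChunk() {
        // Whatever the block reader has refreshed to is what gets used here.
        return "read via " + pipelineSupplier.get() + " with " + tokenSupplier.get();
    }
}

// Simplified stand-in for BlockInputStream: refresh() swaps the holders, and
// every ChunkReader created earlier observes the new values on its next read.
class BlockReader {
    private final AtomicReference<String> pipelineRef = new AtomicReference<>("pipeline-1");
    private final AtomicReference<String> tokenRef = new AtomicReference<>("token-1");

    ChunkReader newChunkReader() {
        return new ChunkReader(pipelineRef::get, tokenRef::get);
    }

    void refresh(String newPipeline, String newToken) {
        pipelineRef.set(newPipeline);
        tokenRef.set(newToken);
    }
}

public class TokenRefreshSketch {
    public static void main(String[] args) {
        BlockReader block = new BlockReader();
        ChunkReader chunk = block.newChunkReader();

        System.out.println(chunk.readChunk());   // read via pipeline-1 with token-1
        block.refresh("pipeline-2", "token-2");  // e.g. after a pipeline change is detected
        System.out.println(chunk.readChunk());   // read via pipeline-2 with token-2
    }
}

Before this change, ChunkInputStream held the token passed to its constructor, so a token obtained during a pipeline refresh was never applied to subsequent readChunk calls; the TestChunkInputStream changes above assert exactly that the new token and new pipeline are used after unbuffer and refresh.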