diff --git a/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java b/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java index c95e91ae7b..16db7c7920 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java @@ -39,6 +39,8 @@ public class FileSegmentManagedBuffer extends ManagedBuffer { private final File file; private final long offset; private final int length; + private volatile boolean isFilled; + private ByteBuffer cachedBuffer; public FileSegmentManagedBuffer(File file, long offset, int length) { this.file = file; @@ -58,21 +60,25 @@ public ByteBuf byteBuf() { @Override public ByteBuffer nioByteBuffer() { + if (isFilled) { + return cachedBuffer; + } FileChannel channel = null; try { channel = new RandomAccessFile(file, "r").getChannel(); - ByteBuffer buf = ByteBuffer.allocate(length); + cachedBuffer = ByteBuffer.allocate(length); channel.position(offset); - while (buf.remaining() != 0) { - if (channel.read(buf) == -1) { + while (cachedBuffer.remaining() != 0) { + if (channel.read(cachedBuffer) == -1) { throw new IOException( String.format( "Reached EOF before filling buffer.offset=%s,file=%s,buf.remaining=%s", - offset, file.getAbsoluteFile(), buf.remaining())); + offset, file.getAbsoluteFile(), cachedBuffer.remaining())); } } - buf.flip(); - return buf; + cachedBuffer.flip(); + isFilled = true; + return cachedBuffer; } catch (IOException e) { String fileName = file.getAbsolutePath(); String errorMessage = @@ -102,6 +108,9 @@ public ManagedBuffer retain() { @Override public ManagedBuffer release() { + cachedBuffer.clear(); + cachedBuffer = null; + isFilled = false; return this; } diff --git a/common/src/test/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBufferTest.java b/common/src/test/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBufferTest.java new file mode 100644 index 0000000000..c8719ef2af --- /dev/null +++ b/common/src/test/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBufferTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.netty.buffer; + +import java.io.File; +import java.io.FileOutputStream; +import java.nio.ByteBuffer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FileSegmentManagedBufferTest { + @Test + void testNioByteBuffer(@TempDir File tmpDir) { + File dataFile = new File(tmpDir, "data_file_1"); + String str = "Hello"; + byte[] strToBytes = str.getBytes(); + try (FileOutputStream outputStream = new FileOutputStream(dataFile)) { + outputStream.write(strToBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + FileSegmentManagedBuffer fileSegmentManagedBuffer = + new FileSegmentManagedBuffer(dataFile, 0, strToBytes.length); + ByteBuffer byteBuffer1 = fileSegmentManagedBuffer.nioByteBuffer(); + assertEquals(new String(byteBuffer1.array()), str); + + ByteBuffer byteBuffer2 = fileSegmentManagedBuffer.nioByteBuffer(); + assertTrue(byteBuffer1 == byteBuffer2); + fileSegmentManagedBuffer.release(); + + fileSegmentManagedBuffer = new FileSegmentManagedBuffer(dataFile, 0, strToBytes.length); + ByteBuffer byteBuffer3 = fileSegmentManagedBuffer.nioByteBuffer(); + assertFalse(byteBuffer3 == byteBuffer2); + assertFalse(byteBuffer3 == byteBuffer1); + + fileSegmentManagedBuffer.release(); + } +}