diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java index 72a9e17639..ed8caf6c75 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java @@ -62,6 +62,7 @@ import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.tez.common.CallableWithNdc; +import org.apache.tez.common.IdUtils; import org.apache.tez.common.InputContextUtils; import org.apache.tez.common.RssTezConfig; import org.apache.tez.common.RssTezUtils; @@ -74,6 +75,7 @@ import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.dag.api.TezException; import org.apache.tez.dag.records.TezTaskAttemptID; +import org.apache.tez.dag.records.TezTaskID; import org.apache.tez.http.HttpConnectionParams; import org.apache.tez.runtime.api.Event; import org.apache.tez.runtime.api.InputContext; @@ -274,6 +276,7 @@ enum ShuffleErrors { private final Map> partitionIdToSuccessMapTaskAttempts = new HashMap<>(); + final Map> partitionIdToSuccessTezTasks = new HashMap<>(); private final String storageType; private final int readBufferSize; @@ -1143,8 +1146,13 @@ private boolean allEventsReceived() { } } + private synchronized boolean allInputTaskAttemptDone() { /* synchronized: addKnownMapOutput() mutates partitionIdToSuccessTezTasks under this same monitor, and this check is also reached without the lock held */ + return this.partitionIdToSuccessTezTasks.values().stream().mapToInt(s -> s.size()).sum() + == numInputs; + } + private boolean isAllInputFetched() { - return allEventsReceived() && (successRssPartitionSet.size() >= allRssPartition.size()); + return allInputTaskAttemptDone() && (successRssPartitionSet.size() >= allRssPartition.size()); } /** @@ -1293,6 +1301,10 @@ public synchronized void addKnownMapOutput( 
partitionIdToSuccessMapTaskAttempts.put(partitionId, new HashSet<>()); } partitionIdToSuccessMapTaskAttempts.get(partitionId).add(srcAttempt); + String pathComponent = srcAttempt.getPathComponent(); + TezTaskAttemptID tezTaskAttemptId = IdUtils.convertTezTaskAttemptID(pathComponent); + partitionIdToSuccessTezTasks.putIfAbsent(partitionId, new HashSet<>()); + partitionIdToSuccessTezTasks.get(partitionId).add(tezTaskAttemptId.getTaskID()); uniqueHosts.add(new HostPort(inputHostName, port)); HostPortPartition identifier = new HostPortPartition(inputHostName, port, partitionId); @@ -1661,10 +1673,10 @@ private class RssShuffleSchedulerCallable extends CallableWithNdc { protected Void callInternal() throws IOException, InterruptedException, TezException, RssException { while (!isShutdown.get() && !isAllInputFetched()) { - LOG.info("Now allEventsReceived: " + allEventsReceived()); + LOG.info("Now allInputTaskAttemptDone: " + allInputTaskAttemptDone()); synchronized (RssShuffleScheduler.this) { - while (!allEventsReceived() + while (!allInputTaskAttemptDone() || ((rssRunningFetchers.size() >= numFetchers || pendingHosts.isEmpty()) && !isAllInputFetched())) { try { diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java index e9f665d226..18fdcf91d4 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleSchedulerTest.java @@ -38,6 +38,10 @@ import org.apache.tez.common.security.JobTokenIdentifier; import org.apache.tez.common.security.JobTokenSecretManager; import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.records.TezDAGID; +import org.apache.tez.dag.records.TezTaskAttemptID; +import 
org.apache.tez.dag.records.TezTaskID; +import org.apache.tez.dag.records.TezVertexID; import org.apache.tez.runtime.api.ExecutionContext; import org.apache.tez.runtime.api.InputContext; import org.apache.tez.runtime.api.impl.ExecutionContextImpl; @@ -54,6 +58,7 @@ import org.mockito.stubbing.Answer; import static org.apache.tez.runtime.library.common.shuffle.impl.RssShuffleManagerTest.APPATTEMPT_ID; +import static org.apache.tez.runtime.library.common.shuffle.impl.RssShuffleManagerTest.APP_ID; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -69,6 +74,10 @@ public class RssShuffleSchedulerTest { + private static final TezTaskAttemptID TEZ_TASK_ATTEMPT_ID = + TezTaskAttemptID.getInstance( + TezTaskID.getInstance(TezVertexID.getInstance(TezDAGID.getInstance(APP_ID, 0), 0), 0), 0); + private TezExecutors sharedExecutor; @BeforeEach @@ -108,7 +117,8 @@ public void testReducerHealth1(Configuration conf) throws IOException { // Generate 320 events for (int i = 0; i < 320; i++) { CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + i, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput( "host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier); } @@ -186,7 +196,8 @@ public void testReducerHealth2() throws IOException, InterruptedException { // Generate 0-200 events for (int i = 0; i < 200; i++) { CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + i, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput( "host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier); } @@ -345,7 +356,8 @@ public void testReducerHealth3() throws 
IOException { // Generate 320 events for (int i = 0; i < 320; i++) { CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + i, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput( "host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier); } @@ -432,7 +444,8 @@ public void testReducerHealth4() throws IOException { // Generate 320 events for (int i = 0; i < 320; i++) { CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + i, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput( "host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier); } @@ -573,7 +586,8 @@ public void testReducerHealth5() throws IOException { // Generate 319 events (last event has not arrived) for (int i = 0; i < 319; i++) { CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + i, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput( "host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier); } @@ -669,7 +683,8 @@ public void testReducerHealth6(Configuration conf) throws IOException { // Generate 320 events (last event has not arrived) for (int i = 0; i < 320; i++) { CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + i, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput( "host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier); } @@ -758,7 +773,8 @@ public void testReducerHealth7() throws IOException { // Generate 320 events for (int i = 0; i < 320; i++) { 
CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + i, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput( "host" + (i % totalProducerNodes), 10000, i, inputAttemptIdentifier); } @@ -854,7 +870,8 @@ public void testPenalty() throws IOException, InterruptedException { final ShuffleSchedulerForTest scheduler = createScheduler(startTime, 1, shuffle); CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(0, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + 0, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput("host0", 10000, 0, inputAttemptIdentifier); assertTrue(scheduler.pendingHosts.size() == 1); @@ -950,7 +967,8 @@ public Void call() throws Exception { for (int i = 0; i < numInputs; i++) { CompositeInputAttemptIdentifier inputAttemptIdentifier = - new CompositeInputAttemptIdentifier(i, 0, "attempt_", 1); + new CompositeInputAttemptIdentifier( + i, 0, String.format("%s_%05d", TEZ_TASK_ATTEMPT_ID.toString(), 0), 1); scheduler.addKnownMapOutput("host" + i, 10000, 1, inputAttemptIdentifier); identifiers[i] = inputAttemptIdentifier; }