From 20dc4dca337d0a0c6d603586f52cde627ababe16 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Tue, 30 Jul 2024 01:58:01 +0200 Subject: [PATCH] Change FnApiDoFnRunner to skip trySplit checkpoint requests if not draining and nothing has yet been claimed by the tracker. --- .../beam/fn/harness/FnApiDoFnRunner.java | 40 ++- .../beam/fn/harness/FnApiDoFnRunnerTest.java | 303 ++++++++++++++++-- 2 files changed, 313 insertions(+), 30 deletions(-) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index f85622ab89fe..7421a64eeba4 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.NavigableSet; import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -327,6 +328,11 @@ static class Factory currentTracker; + /** + * If non-null, set to true after currentTracker has had a tryClaim issued on it. Used to ignore + * checkpoint split requests if no progress was made. + */ + private AtomicBoolean currentTrackerClaimed; /** * Only valid during {@link #processTimer} and {@link #processOnWindowExpiration}, null otherwise. @@ -877,12 +883,15 @@ private void processElementForSplitRestriction( currentElement = elem.withValue(elem.getValue().getKey()); currentRestriction = elem.getValue().getValue().getKey(); currentWatermarkEstimatorState = elem.getValue().getValue().getValue(); + currentTrackerClaimed = new AtomicBoolean(false); currentTracker = RestrictionTrackers.observe( doFnInvoker.invokeNewTracker(processContext), new ClaimObserver() { @Override - public void onClaimed(PositionT position) {} + public void onClaimed(PositionT position) { + currentTrackerClaimed.lazySet(true); + } @Override public void onClaimFailed(PositionT position) {} @@ -894,6 +903,7 @@ public void onClaimFailed(PositionT position) {} currentRestriction = null; currentWatermarkEstimatorState = null; currentTracker = null; + currentTrackerClaimed = null; } this.stateAccessor.finalizeState(); @@ -909,12 +919,15 @@ private void processElementForWindowObservingSplitRestriction( (Iterator) elem.getWindows().iterator(); while (windowIterator.hasNext()) { currentWindow = windowIterator.next(); + currentTrackerClaimed = new AtomicBoolean(false); currentTracker = RestrictionTrackers.observe( doFnInvoker.invokeNewTracker(processContext), new ClaimObserver() { @Override - public void onClaimed(PositionT position) {} + public void onClaimed(PositionT position) { + currentTrackerClaimed.lazySet(true); + } @Override public void onClaimFailed(PositionT position) {} @@ -927,6 +940,7 @@ public void onClaimFailed(PositionT position) {} currentWatermarkEstimatorState = null; currentWindow = null; currentTracker = null; + currentTrackerClaimed = null; } this.stateAccessor.finalizeState(); @@ -937,6 +951,8 @@ private void processElementForTruncateRestriction( currentElement = elem.withValue(elem.getValue().getKey().getKey()); currentRestriction = elem.getValue().getKey().getValue().getKey(); currentWatermarkEstimatorState = elem.getValue().getKey().getValue().getValue(); + // For truncation, we don't set currentTrackerClaimed so that we enable checkpointing even if no + // progress is made. currentTracker = RestrictionTrackers.observe( doFnInvoker.invokeNewTracker(processContext), @@ -989,6 +1005,8 @@ private void processElementForWindowObservingTruncateRestriction( currentRestriction = elem.getValue().getKey().getValue().getKey(); currentWatermarkEstimatorState = elem.getValue().getKey().getValue().getValue(); currentWindow = currentWindows.get(windowCurrentIndex); + // We leave currentTrackerClaimed unset as we want to split regardless of if tryClaim is + // called. currentTracker = RestrictionTrackers.observe( doFnInvoker.invokeNewTracker(processContext), @@ -1081,12 +1099,15 @@ private void processElementForWindowObservingSizedElementAndRestriction( currentRestriction = elem.getValue().getKey().getValue().getKey(); currentWatermarkEstimatorState = elem.getValue().getKey().getValue().getValue(); currentWindow = currentWindows.get(windowCurrentIndex); + currentTrackerClaimed = new AtomicBoolean(false); currentTracker = RestrictionTrackers.observe( doFnInvoker.invokeNewTracker(processContext), new ClaimObserver() { @Override - public void onClaimed(PositionT position) {} + public void onClaimed(PositionT position) { + currentTrackerClaimed.lazySet(true); + } @Override public void onClaimFailed(PositionT position) {} @@ -1278,6 +1299,13 @@ private HandlesSplits.SplitResult trySplitForWindowObservingTruncateRestriction( if (currentWindow == null) { return null; } + // We are requesting a checkpoint but have not yet progressed on the restriction, skip + // request. + if (fractionOfRemainder == 0 + && currentTrackerClaimed != null + && !currentTrackerClaimed.get()) { + return null; + } SplitResultsWithStopIndex splitResult = computeSplitForProcessOrTruncate( @@ -1628,6 +1656,12 @@ private HandlesSplits.SplitResult trySplitForElementAndRestriction( if (currentTracker == null) { return null; } + // The tracker has not yet been claimed meaning that a checkpoint won't meaningfully advance. + if (fractionOfRemainder == 0 + && currentTrackerClaimed != null + && !currentTrackerClaimed.get()) { + return null; + } // Make sure to get the output watermark before we split to ensure that the lower bound // applies to the residual. watermarkAndState = currentWatermarkEstimator.getWatermarkAndState(); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index 11f25ab0116e..56682538cff7 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -151,6 +151,7 @@ import org.joda.time.Duration; import org.joda.time.Instant; import org.joda.time.format.PeriodFormat; +import org.junit.Assert; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; @@ -1370,41 +1371,74 @@ public void testRegistration() { *
  • splitting thread: {@link * NonWindowObservingTestSplittableDoFn#waitForSplitElementToBeProcessed()} *
  • process element thread: {@link - * NonWindowObservingTestSplittableDoFn#enableAndWaitForTrySplitToHappen()} + * NonWindowObservingTestSplittableDoFn#splitElementProcessed()} *
  • splitting thread: perform try split - *
  • splitting thread: {@link - * NonWindowObservingTestSplittableDoFn#releaseWaitingProcessElementThread()} + *
  • splitting thread: {@link NonWindowObservingTestSplittableDoFn#trySplitPerformed()} * + *
  • process element thread: {@link + * NonWindowObservingTestSplittableDoFn#waitForTrySplitPerformed()} * */ static class NonWindowObservingTestSplittableDoFn extends DoFn { - private static final ConcurrentMap> - DOFN_INSTANCE_TO_LOCK = new ConcurrentHashMap<>(); + private static final ConcurrentMap DOFN_INSTANCE_TO_LATCHES = + new ConcurrentHashMap<>(); private static final long SPLIT_ELEMENT = 3; private static final long CHECKPOINT_UPPER_BOUND = 8; - private KV getLatches() { - return DOFN_INSTANCE_TO_LOCK.computeIfAbsent( - this.uuid, (uuid) -> KV.of(new CountDownLatch(1), new CountDownLatch(1))); + static class Latches { + public Latches() {} + + CountDownLatch blockProcessLatch = new CountDownLatch(0); + CountDownLatch processEnteredLatch = new CountDownLatch(1); + CountDownLatch splitElementProcessedLatch = new CountDownLatch(1); + CountDownLatch trySplitPerformedLatch = new CountDownLatch(1); + } + + private Latches getLatches() { + return DOFN_INSTANCE_TO_LATCHES.computeIfAbsent(this.uuid, (uuid) -> new Latches()); + } + + public void splitElementProcessed() { + getLatches().splitElementProcessedLatch.countDown(); } - public void enableAndWaitForTrySplitToHappen() throws Exception { - KV latches = getLatches(); - latches.getKey().countDown(); - if (!latches.getValue().await(30, TimeUnit.SECONDS)) { + public void waitForSplitElementToBeProcessed() throws InterruptedException { + if (!getLatches().splitElementProcessedLatch.await(30, TimeUnit.SECONDS)) { fail("Failed to wait for trySplit to occur."); } } - public void waitForSplitElementToBeProcessed() throws Exception { - KV latches = getLatches(); - if (!latches.getKey().await(30, TimeUnit.SECONDS)) { - fail("Failed to wait for split element to be processed."); + public void trySplitPerformed() { + getLatches().trySplitPerformedLatch.countDown(); + } + + public void waitForTrySplitPerformed() throws InterruptedException { + if (!getLatches().trySplitPerformedLatch.await(30, TimeUnit.SECONDS)) { + fail("Failed to wait for trySplit to occur."); + } + } + + // Must be called before process is invoked. Will prevent process from doing anything until + // unblockProcess is + // called. + public void setupBlockProcess() { + getLatches().blockProcessLatch = new CountDownLatch(1); + } + + public void enterProcessAndBlockIfEnabled() throws InterruptedException { + getLatches().processEnteredLatch.countDown(); + if (!getLatches().blockProcessLatch.await(30, TimeUnit.SECONDS)) { + fail("Failed to wait for unblockProcess to occur."); + } + } + + public void waitForProcessEntered() throws InterruptedException { + if (!getLatches().processEnteredLatch.await(5, TimeUnit.SECONDS)) { + fail("Failed to wait for process to begin."); } } - public void releaseWaitingProcessElementThread() { - KV latches = getLatches(); - latches.getValue().countDown(); + public void unblockProcess() throws InterruptedException { + getLatches().blockProcessLatch.countDown(); } private final String uuid; @@ -1427,7 +1461,8 @@ public ProcessContinuation processElement( if (!claimStatus) { break; } else if (position == SPLIT_ELEMENT) { - enableAndWaitForTrySplitToHappen(); + splitElementProcessed(); + waitForTrySplitPerformed(); } context.outputWithTimestamp( context.element() + ":" + position, @@ -1511,6 +1546,7 @@ public ProcessContinuation processElement( RestrictionTracker tracker, ManualWatermarkEstimator watermarkEstimator) throws Exception { + enterProcessAndBlockIfEnabled(); long checkpointUpperBound = Long.parseLong(context.sideInput(singletonSideInput)); long position = tracker.currentRestriction().getFrom(); boolean claimStatus; @@ -1519,7 +1555,8 @@ public ProcessContinuation processElement( if (!claimStatus) { break; } else if (position == NonWindowObservingTestSplittableDoFn.SPLIT_ELEMENT) { - enableAndWaitForTrySplitToHappen(); + splitElementProcessed(); + waitForTrySplitPerformed(); } context.outputWithTimestamp( context.element() + ":" + position, @@ -1549,7 +1586,8 @@ public TruncateResult truncateRestriction(@Restriction OffsetRange throws Exception { // Waiting for split when we are on the second window. if (splitAtTruncate && processedWindowCount == PROCESSED_WINDOW) { - enableAndWaitForTrySplitToHappen(); + splitElementProcessed(); + waitForTrySplitPerformed(); } processedWindowCount += 1; return TruncateResult.of(new OffsetRange(range.getFrom(), range.getTo() / 2)); @@ -1755,7 +1793,217 @@ public void testProcessElementForSizedElementAndRestriction() throws Exception { return ((HandlesSplits) mainInput).trySplit(0); } finally { - doFn.releaseWaitingProcessElementThread(); + doFn.trySplitPerformed(); + } + }); + + // Check that before processing an element we don't report progress + assertNoReportedProgress(context.getBundleProgressReporters()); + mainInput.accept( + valueInGlobalWindow( + KV.of( + KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), + 2.0))); + HandlesSplits.SplitResult trySplitResult = trySplitFuture.get(); + + // Check that after processing an element we don't report progress + assertNoReportedProgress(context.getBundleProgressReporters()); + + // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split. + // We expect that the watermark advances to MIN + 2 since the manual watermark estimator + // has yet to be invoked for the split element and that the primary represents [0, 4) with + // the original watermark while the residual represents [4, 5) with the new MIN + 2 + // watermark. + assertThat( + mainOutputValues, + contains( + timestampedValueInGlobalWindow( + "7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(0))), + timestampedValueInGlobalWindow( + "7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(1))), + timestampedValueInGlobalWindow( + "7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2))), + timestampedValueInGlobalWindow( + "7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(3))))); + + BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots()); + DelayedBundleApplication residualRoot = + Iterables.getOnlyElement(trySplitResult.getResidualRoots()); + assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId()); + assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId()); + assertEquals( + ParDoTranslation.getMainInputName(pTransform), + residualRoot.getApplication().getInputId()); + assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId()); + assertEquals( + valueInGlobalWindow( + KV.of( + KV.of("7", KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE)), + 4.0)), + inputCoder.decode(primaryRoot.getElement().newInput())); + assertEquals( + valueInGlobalWindow( + KV.of( + KV.of( + "7", + KV.of( + new OffsetRange(4, 5), + GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2)))), + 1.0)), + inputCoder.decode(residualRoot.getApplication().getElement().newInput())); + Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(2)); + assertEquals( + ImmutableMap.of( + "output", + org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Timestamp.newBuilder() + .setSeconds(expectedOutputWatermark.getMillis() / 1000) + .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000) + .build()), + residualRoot.getApplication().getOutputWatermarksMap()); + // We expect 0 resume delay. + assertEquals( + residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(), + residualRoot.getRequestedTimeDelay()); + // We don't expect the outputs to goto the SDK initiated checkpointing listener. + assertTrue(splitListener.getPrimaryRoots().isEmpty()); + assertTrue(splitListener.getResidualRoots().isEmpty()); + mainOutputValues.clear(); + executorService.shutdown(); + } + + Iterables.getOnlyElement(context.getFinishBundleFunctions()).run(); + assertThat(mainOutputValues, empty()); + + Iterables.getOnlyElement(context.getTearDownFunctions()).run(); + assertThat(mainOutputValues, empty()); + + // Assert that state data did not change + assertEquals( + new FakeBeamFnStateClient(StringUtf8Coder.of(), stateData).getData(), + fakeClient.getData()); + } + + @Test + public void testProcessElementForSizedElementAndRestrictionSplitBeforeTryClaim() + throws Exception { + Pipeline p = Pipeline.create(); + addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api"); + // TODO(BEAM-10097): Remove experiment once all portable runners support this view type + addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2"); + PCollection valuePCollection = p.apply(Create.of("unused")); + PCollectionView singletonSideInputView = valuePCollection.apply(View.asSingleton()); + WindowObservingTestSplittableDoFn doFn = + new WindowObservingTestSplittableDoFn(singletonSideInputView); + valuePCollection.apply( + TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView)); + + RunnerApi.Pipeline pProto = + ProtoOverrides.updateTransform( + PTransformTranslation.PAR_DO_TRANSFORM_URN, + PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true), + SplittableParDoExpander.createSizedReplacement()); + String expandedTransformId = + Iterables.find( + pProto.getComponents().getTransformsMap().entrySet(), + entry -> + entry + .getValue() + .getSpec() + .getUrn() + .equals( + PTransformTranslation + .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN) + && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID)) + .getKey(); + RunnerApi.PTransform pTransform = + pProto.getComponents().getTransformsOrThrow(expandedTransformId); + String inputPCollectionId = + pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform)); + RunnerApi.PCollection inputPCollection = + pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId); + RehydratedComponents rehydratedComponents = + RehydratedComponents.forComponents(pProto.getComponents()); + Coder inputCoder = + WindowedValue.getFullCoder( + CoderTranslation.fromProto( + pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()), + rehydratedComponents, + TranslationContext.DEFAULT), + (Coder) + CoderTranslation.fromProto( + pProto + .getComponents() + .getCodersOrThrow( + pProto + .getComponents() + .getWindowingStrategiesOrThrow( + inputPCollection.getWindowingStrategyId()) + .getWindowCoderId()), + rehydratedComponents, + TranslationContext.DEFAULT)); + String outputPCollectionId = pTransform.getOutputsOrThrow("output"); + + ImmutableMap> stateData = + ImmutableMap.of( + iterableSideInputKey( + singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY), + asList("8")); + + FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(StringUtf8Coder.of(), stateData); + + BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create(); + + PTransformRunnerFactoryTestContext context = + PTransformRunnerFactoryTestContext.builder(TEST_TRANSFORM_ID, pTransform) + .beamFnStateClient(fakeClient) + .processBundleInstructionId("57") + .pCollections(pProto.getComponentsOrBuilder().getPcollectionsMap()) + .coders(pProto.getComponents().getCodersMap()) + .windowingStrategies(pProto.getComponents().getWindowingStrategiesMap()) + .splitListener(splitListener) + .build(); + List> mainOutputValues = new ArrayList<>(); + context.addPCollectionConsumer( + outputPCollectionId, + (FnDataReceiver) (FnDataReceiver>) mainOutputValues::add); + + new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(context); + + Iterables.getOnlyElement(context.getStartBundleFunctions()).run(); + mainOutputValues.clear(); + + assertThat( + context.getPCollectionConsumers().keySet(), + containsInAnyOrder(inputPCollectionId, outputPCollectionId)); + + FnDataReceiver> mainInput = + context.getPCollectionConsumer(inputPCollectionId); + assertThat(mainInput, instanceOf(HandlesSplits.class)); + + doFn.setupBlockProcess(); + { + // Setup and launch the trySplit thread. + ExecutorService executorService = Executors.newSingleThreadExecutor(); + Future trySplitFuture = + executorService.submit( + () -> { + try { + // Verify that a split before anything is claimed is ignored. + doFn.waitForProcessEntered(); + Assert.assertNull(((HandlesSplits) mainInput).trySplit(0)); + doFn.unblockProcess(); + + doFn.waitForSplitElementToBeProcessed(); + // Currently processing "3" out of range [0, 5) elements. + assertEquals(0.6, ((HandlesSplits) mainInput).getProgress(), 0.01); + + // Check that during progressing of an element we report progress + assertReportedProgressEquals( + context.getShortIdMap(), context.getBundleProgressReporters(), 3.0, 2.0); + + return ((HandlesSplits) mainInput).trySplit(0); + } finally { + doFn.trySplitPerformed(); } }); @@ -2187,7 +2435,7 @@ public void testProcessElementForWindowedSizedElementAndRestriction() throws Exc return ((HandlesSplits) mainInput).trySplit(0); } finally { - doFn.releaseWaitingProcessElementThread(); + doFn.trySplitPerformed(); } }); @@ -3143,10 +3391,11 @@ public void testProcessElementForTruncateAndSizeRestrictionForwardSplitWhenObser () -> { try { doFn.waitForSplitElementToBeProcessed(); - - return ((HandlesSplits) mainInput).trySplit(0); + HandlesSplits.SplitResult result = ((HandlesSplits) mainInput).trySplit(0); + Assert.assertNotNull(result); + return result; } finally { - doFn.releaseWaitingProcessElementThread(); + doFn.trySplitPerformed(); } });