diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java index db5dfcf6825d..93fe4ddbfaf6 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java @@ -25,6 +25,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; +import java.util.ServiceLoader; import java.util.UUID; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -32,8 +34,9 @@ import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.model.pipeline.v1.RunnerApi; -import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; import org.apache.beam.model.pipeline.v1.SchemaApi; +import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transformservice.launcher.TransformServiceLauncher; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; @@ -41,6 +44,7 @@ import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannelBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.checkerframework.checker.nullness.qual.Nullable; /** * A utility class that allows upgrading transforms of a given pipeline using the Beam Transform @@ -145,7 +149,7 @@ RunnerApi.Pipeline updateTransformViaTransformService( String transformId, Endpoints.ApiServiceDescriptor transformServiceEndpoint) throws IOException { - PTransform transformToUpgrade = + RunnerApi.PTransform transformToUpgrade = runnerAPIpipeline.getComponents().getTransformsMap().get(transformId); if (transformToUpgrade == null) { throw new IllegalArgumentException("Could not find a transform with the ID " + transformId); @@ -252,7 +256,7 @@ RunnerApi.Pipeline updateTransformViaTransformService( recursivelyFindSubTransforms( transformId, runnerAPIpipeline.getComponents(), transformsToRemove); - Map updatedExpandedTransformMap = + Map updatedExpandedTransformMap = expandedComponents.getTransformsMap().entrySet().stream() .filter( entry -> { @@ -265,7 +269,7 @@ RunnerApi.Pipeline updateTransformViaTransformService( entry -> { // Fix inputs Map inputsMap = entry.getValue().getInputsMap(); - PTransform.Builder transformBuilder = entry.getValue().toBuilder(); + RunnerApi.PTransform.Builder transformBuilder = entry.getValue().toBuilder(); if (!Collections.disjoint(inputsMap.values(), inputReplacements.keySet())) { Map updatedInputsMap = new HashMap<>(); for (Map.Entry inputEntry : inputsMap.entrySet()) { @@ -297,7 +301,7 @@ RunnerApi.Pipeline updateTransformViaTransformService( private static void recursivelyFindSubTransforms( String transformId, RunnerApi.Components components, List results) { results.add(transformId); - PTransform transform = components.getTransformsMap().get(transformId); + RunnerApi.PTransform transform = components.getTransformsMap().get(transformId); if (transform == null) { throw new IllegalArgumentException("Could not find a transform with id " + transformId); } @@ -328,4 
+332,39 @@ private static int findAvailablePort() throws IOException { public void close() throws Exception { clientFactory.close(); } + + /** + * A utility to find the registered URN for a given transform. + * + *
This URN can be used to upgrade this transform to a new Beam version without upgrading the + * rest of the pipeline. Please see Beam + * Transform Service documentation for more details. + * + *
For this lookup to work, the a {@link TransformPayloadTranslatorRegistrar} for the transform + * has to be available in the classpath. + * + * @param transform transform to lookup. + * @return a URN if discovered. Returns {@code null} otherwise. + */ + @SuppressWarnings({ + "rawtypes", + "EqualsIncompatibleType", + }) + public static @Nullable String findUpgradeURN(PTransform transform) { + for (TransformPayloadTranslatorRegistrar registrar : + ServiceLoader.load(TransformPayloadTranslatorRegistrar.class)) { + + for (Entry< + ? extends Class, + ? extends TransformPayloadTranslator> + entry : registrar.getTransformPayloadTranslators().entrySet()) { + if (entry.getKey().equals(transform.getClass())) { + return entry.getValue().getUrn(); + } + } + } + + return null; + } } diff --git a/sdks/java/io/expansion-service/build.gradle b/sdks/java/io/expansion-service/build.gradle index 254bbf356c9b..9ab71ff16d3d 100644 --- a/sdks/java/io/expansion-service/build.gradle +++ b/sdks/java/io/expansion-service/build.gradle @@ -35,6 +35,8 @@ dependencies { permitUnusedDeclared project(":sdks:java:expansion-service") // BEAM-11761 implementation project(":sdks:java:io:kafka") permitUnusedDeclared project(":sdks:java:io:kafka") // BEAM-11761 + implementation project(":sdks:java:io:kafka:upgrade") + permitUnusedDeclared project(":sdks:java:io:kafka:upgrade") // BEAM-11761 runtimeOnly library.java.kafka_clients runtimeOnly library.java.slf4j_jdk14 } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index 26f6c3448801..7e4fc55c6ce9 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -626,66 +626,70 @@ public static WriteRecords writeRecords() { @AutoValue.CopyAnnotations public abstract static class Read extends PTransform>> { + + public static final Class AUTOVALUE_CLASS = + AutoValue_KafkaIO_Read.class; + @Pure - abstract Map getConsumerConfig(); + public abstract Map getConsumerConfig(); @Pure - abstract @Nullable List getTopics(); + public abstract @Nullable List getTopics(); @Pure - abstract @Nullable List getTopicPartitions(); + public abstract @Nullable List getTopicPartitions(); @Pure - abstract @Nullable Pattern getTopicPattern(); + public abstract @Nullable Pattern getTopicPattern(); @Pure - abstract @Nullable Coder getKeyCoder(); + public abstract @Nullable Coder getKeyCoder(); @Pure - abstract @Nullable Coder getValueCoder(); + public abstract @Nullable Coder getValueCoder(); @Pure - abstract SerializableFunction, Consumer> + public abstract SerializableFunction, Consumer> getConsumerFactoryFn(); @Pure - abstract @Nullable SerializableFunction, Instant> getWatermarkFn(); + public abstract @Nullable SerializableFunction, Instant> getWatermarkFn(); @Pure - abstract long getMaxNumRecords(); + public abstract long getMaxNumRecords(); @Pure - abstract @Nullable Duration getMaxReadTime(); + public abstract @Nullable Duration getMaxReadTime(); @Pure - abstract @Nullable Instant getStartReadTime(); + public abstract @Nullable Instant getStartReadTime(); @Pure - abstract @Nullable Instant getStopReadTime(); + public abstract @Nullable Instant getStopReadTime(); @Pure - abstract boolean isCommitOffsetsInFinalizeEnabled(); + public abstract boolean isCommitOffsetsInFinalizeEnabled(); @Pure - abstract boolean isDynamicRead(); + public abstract boolean isDynamicRead(); @Pure - 
abstract @Nullable Duration getWatchTopicPartitionDuration(); + public abstract @Nullable Duration getWatchTopicPartitionDuration(); @Pure - abstract TimestampPolicyFactory getTimestampPolicyFactory(); + public abstract TimestampPolicyFactory getTimestampPolicyFactory(); @Pure - abstract @Nullable Map getOffsetConsumerConfig(); + public abstract @Nullable Map getOffsetConsumerConfig(); @Pure - abstract @Nullable DeserializerProvider getKeyDeserializerProvider(); + public abstract @Nullable DeserializerProvider getKeyDeserializerProvider(); @Pure - abstract @Nullable DeserializerProvider getValueDeserializerProvider(); + public abstract @Nullable DeserializerProvider getValueDeserializerProvider(); @Pure - abstract @Nullable CheckStopReadingFn getCheckStopReadingFn(); + public abstract @Nullable CheckStopReadingFn getCheckStopReadingFn(); abstract Builder toBuilder(); @@ -996,6 +1000,14 @@ public Read withKeyDeserializer(DeserializerProvider deserializerProvid return toBuilder().setKeyDeserializerProvider(deserializerProvider).build(); } + public Read withKeyDeserializerProviderAndCoder( + DeserializerProvider deserializerProvider, Coder keyCoder) { + return toBuilder() + .setKeyDeserializerProvider(deserializerProvider) + .setKeyCoder(keyCoder) + .build(); + } + /** * Sets a Kafka {@link Deserializer} to interpret value bytes read from Kafka. * @@ -1024,6 +1036,14 @@ public Read withValueDeserializer(DeserializerProvider deserializerProv return toBuilder().setValueDeserializerProvider(deserializerProvider).build(); } + public Read withValueDeserializerProviderAndCoder( + DeserializerProvider deserializerProvider, Coder valueCoder) { + return toBuilder() + .setValueDeserializerProvider(deserializerProvider) + .setValueCoder(valueCoder) + .build(); + } + /** * A factory to create Kafka {@link Consumer} from consumer configuration. This is useful for * supporting another version of Kafka consumer. Default is {@link KafkaConsumer}. @@ -2485,37 +2505,37 @@ public abstract static class WriteRecords // {@link WriteRecords}. See example at {@link PubsubIO.Write}. @Pure - abstract @Nullable String getTopic(); + public abstract @Nullable String getTopic(); @Pure - abstract Map getProducerConfig(); + public abstract Map getProducerConfig(); @Pure - abstract @Nullable SerializableFunction, Producer> + public abstract @Nullable SerializableFunction, Producer> getProducerFactoryFn(); @Pure - abstract @Nullable Class> getKeySerializer(); + public abstract @Nullable Class> getKeySerializer(); @Pure - abstract @Nullable Class> getValueSerializer(); + public abstract @Nullable Class> getValueSerializer(); @Pure - abstract @Nullable KafkaPublishTimestampFunction> + public abstract @Nullable KafkaPublishTimestampFunction> getPublishTimestampFunction(); // Configuration for EOS sink @Pure - abstract boolean isEOS(); + public abstract boolean isEOS(); @Pure - abstract @Nullable String getSinkGroupId(); + public abstract @Nullable String getSinkGroupId(); @Pure - abstract int getNumShards(); + public abstract int getNumShards(); @Pure - abstract @Nullable SerializableFunction, ? extends Consumer> + public abstract @Nullable SerializableFunction, ? extends Consumer> getConsumerFactoryFn(); abstract Builder toBuilder(); @@ -2777,9 +2797,12 @@ public abstract static class Write extends PTransform // we shouldn't have to duplicate the same API for similar transforms like {@link Write} and // {@link WriteRecords}. See example at {@link PubsubIO.Write}. 
+ public static final Class AUTOVALUE_CLASS = + AutoValue_KafkaIO_Write.class; + abstract @Nullable String getTopic(); - abstract WriteRecords getWriteRecordsTransform(); + public abstract WriteRecords getWriteRecordsTransform(); abstract Builder toBuilder(); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java index ad17369715dd..748418d16664 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java @@ -34,7 +34,7 @@ * Common utility functions and default configurations for {@link KafkaIO.Read} and {@link * KafkaIO.ReadSourceDescriptors}. */ -final class KafkaIOUtils { +public final class KafkaIOUtils { // A set of config defaults. static final Map DEFAULT_CONSUMER_PROPERTIES = ImmutableMap.of( @@ -61,7 +61,7 @@ final class KafkaIOUtils { false); // A set of properties that are not required or don't make sense for our consumer. - static final Map DISALLOWED_CONSUMER_PROPERTIES = + public static final Map DISALLOWED_CONSUMER_PROPERTIES = ImmutableMap.of( ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "Set keyDeserializer instead", ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "Set valueDeserializer instead" diff --git a/sdks/java/io/kafka/upgrade/build.gradle b/sdks/java/io/kafka/upgrade/build.gradle new file mode 100644 index 000000000000..78776e6e8264 --- /dev/null +++ b/sdks/java/io/kafka/upgrade/build.gradle @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.util.stream.Collectors + +plugins { id 'org.apache.beam.module' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.io.kafka.upgrade', +) + +description = "Apache Beam :: SDKs :: Java :: IO :: Kafka :: Upgrade" +ext.summary = "Library to support upgrading Kafka transforms without upgrading the pipeline." 
+ +dependencies { + implementation project(path: ":model:pipeline", configuration: "shadow") + implementation project(":sdks:java:io:kafka") + implementation library.java.vendored_grpc_1_54_0 + implementation project(path: ":sdks:java:core", configuration: "shadow") + + implementation library.java.vendored_guava_32_1_2_jre + implementation project(":runners:core-construction-java") + implementation project(":sdks:java:expansion-service") + permitUnusedDeclared project(":sdks:java:expansion-service") // BEAM-11761 + permitUnusedDeclared project(":model:pipeline") + implementation library.java.joda_time + // Get back to "provided" since 2.14 + provided library.java.kafka_clients + testImplementation library.java.junit + testImplementation library.java.kafka_clients + +} diff --git a/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java new file mode 100644 index 000000000000..1757da6a6311 --- /dev/null +++ b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.kafka.upgrade; + +import com.google.auto.service.AutoService; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; +import org.apache.beam.runners.core.construction.SdkComponents; +import org.apache.beam.runners.core.construction.TransformPayloadTranslatorRegistrar; +import org.apache.beam.sdk.io.kafka.DeserializerProvider; +import org.apache.beam.sdk.io.kafka.KafkaIO; +import org.apache.beam.sdk.io.kafka.KafkaIO.Read; +import org.apache.beam.sdk.io.kafka.KafkaIO.Write; +import org.apache.beam.sdk.io.kafka.KafkaIO.WriteRecords; +import org.apache.beam.sdk.io.kafka.KafkaIOUtils; +import org.apache.beam.sdk.io.kafka.TimestampPolicyFactory; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.logicaltypes.NanosDuration; +import org.apache.beam.sdk.schemas.logicaltypes.NanosInstant; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.TopicPartition; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; + +/** + * Utility methods for translating {@link KafkaIO} transforms to and from {@link RunnerApi} + * representations. + */ +public class KafkaIOTranslation { + + // We define new v2 URNs here for KafkaIO transforms that includes all properties of Java + // transforms. Kafka read/write v1 URNs are defined in KafkaIO.java and offer a limited set of + // properties adjusted to cross-language usage with portable types. 
+ public static final String KAFKA_READ_WITH_METADATA_TRANSFORM_URN_V2 = + "beam:transform:org.apache.beam:kafka_read_with_metadata:v2"; + public static final String KAFKA_WRITE_TRANSFORM_URN_V2 = + "beam:transform:org.apache.beam:kafka_write:v2"; + + private static byte[] toByteArray(Object object) { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeObject(object); + return bos.toByteArray(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static Object fromByteArray(byte[] bytes) { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInputStream in = new ObjectInputStream(bis)) { + return in.readObject(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + static class KafkaIOReadWithMetadataTranslator implements TransformPayloadTranslator> { + + static Schema topicPartitionSchema = + Schema.builder().addStringField("topic").addInt32Field("partition").build(); + + static Schema schema = + Schema.builder() + .addMapField("consumer_config", FieldType.STRING, FieldType.BYTES) + .addNullableArrayField("topics", FieldType.STRING) + .addNullableArrayField("topic_partitions", FieldType.row(topicPartitionSchema)) + .addNullableStringField("topic_pattern") + .addNullableByteArrayField("key_coder") + .addNullableByteArrayField("value_coder") + .addByteArrayField("consumer_factory_fn") + .addNullableByteArrayField("watermark_fn") + .addInt64Field("max_num_records") + .addNullableLogicalTypeField("max_read_time", new NanosDuration()) + .addNullableLogicalTypeField("start_read_time", new NanosInstant()) + .addNullableLogicalTypeField("stop_read_time", new NanosInstant()) + .addBooleanField("is_commit_offset_finalize_enabled") + .addBooleanField("is_dynamic_read") + .addNullableLogicalTypeField("watch_topic_partition_duration", new NanosDuration()) + .addByteArrayField("timestamp_policy_factory") + .addNullableMapField("offset_consumer_config", FieldType.STRING, FieldType.BYTES) + .addNullableByteArrayField("key_deserializer_provider") + .addNullableByteArrayField("value_deserializer_provider") + .addNullableByteArrayField("check_stop_reading_fn") + .build(); + + @Override + public String getUrn() { + return KAFKA_READ_WITH_METADATA_TRANSFORM_URN_V2; + } + + @Override + @SuppressWarnings({ + "type.argument", + }) + public RunnerApi.@Nullable FunctionSpec translate( + AppliedPTransform> application, SdkComponents components) + throws IOException { + // Setting an empty payload since Kafka transform payload is not actually used by runners + // currently. + // This can be implemented if runners started actually using the Kafka transform payload. 
+ return FunctionSpec.newBuilder().setUrn(getUrn()).setPayload(ByteString.empty()).build(); + } + + @Override + public Row toConfigRow(Read transform) { + Map fieldValues = new HashMap<>(); + Map consumerConfigMap = new HashMap<>(); + transform + .getConsumerConfig() + .forEach( + (key, val) -> { + consumerConfigMap.put(key, toByteArray(val)); + }); + fieldValues.put("consumer_config", consumerConfigMap); + if (transform.getTopics() != null) { + fieldValues.put("topics", transform.getTopics()); + } + + if (transform.getTopicPartitions() != null) { + List encodedTopicPartitions = new ArrayList<>(); + for (TopicPartition topicPartition : transform.getTopicPartitions()) { + encodedTopicPartitions.add( + Row.withSchema(topicPartitionSchema) + .addValue(topicPartition.topic()) + .addValue(topicPartition.partition()) + .build()); + } + fieldValues.put("topic_partitions", encodedTopicPartitions); + } + if (transform.getTopicPattern() != null) { + fieldValues.put("topic_pattern", transform.getTopicPattern().pattern()); + } + if (transform.getKeyCoder() != null) { + fieldValues.put("key_coder", toByteArray(transform.getKeyCoder())); + } + if (transform.getValueCoder() != null) { + fieldValues.put("value_coder", toByteArray(transform.getValueCoder())); + } + if (transform.getConsumerFactoryFn() != null) { + fieldValues.put("consumer_factory_fn", toByteArray(transform.getConsumerFactoryFn())); + } + if (transform.getWatermarkFn() != null) { + fieldValues.put("watermark_fn", toByteArray(transform.getWatermarkFn())); + } + fieldValues.put("max_num_records", transform.getMaxNumRecords()); + if (transform.getMaxReadTime() != null) { + fieldValues.put("max_read_time", transform.getMaxReadTime()); + } + if (transform.getStartReadTime() != null) { + fieldValues.put("start_read_time", transform.getStartReadTime()); + } + if (transform.getStopReadTime() != null) { + fieldValues.put("stop_read_time", transform.getStopReadTime()); + } + + fieldValues.put( + "is_commit_offset_finalize_enabled", transform.isCommitOffsetsInFinalizeEnabled()); + fieldValues.put("is_dynamic_read", transform.isDynamicRead()); + if (transform.getWatchTopicPartitionDuration() != null) { + fieldValues.put( + "watch_topic_partition_duration", transform.getWatchTopicPartitionDuration()); + } + + fieldValues.put( + "timestamp_policy_factory", toByteArray(transform.getTimestampPolicyFactory())); + + if (transform.getOffsetConsumerConfig() != null) { + Map offsetConsumerConfigMap = new HashMap<>(); + transform + .getOffsetConsumerConfig() + .forEach( + (key, val) -> { + offsetConsumerConfigMap.put(key, toByteArray(val)); + }); + fieldValues.put("offset_consumer_config", offsetConsumerConfigMap); + } + if (transform.getKeyDeserializerProvider() != null) { + fieldValues.put( + "key_deserializer_provider", toByteArray(transform.getKeyDeserializerProvider())); + } + if (transform.getValueDeserializerProvider() != null) { + fieldValues.put( + "value_deserializer_provider", toByteArray(transform.getValueDeserializerProvider())); + } + if (transform.getCheckStopReadingFn() != null) { + fieldValues.put("check_stop_reading_fn", toByteArray(transform.getCheckStopReadingFn())); + } + + return Row.withSchema(schema).withFieldValues(fieldValues).build(); + } + + @Override + public Read fromConfigRow(Row configRow) { + Read transform = KafkaIO.read(); + + Map consumerConfig = configRow.getMap("consumer_config"); + if (consumerConfig != null) { + Map updatedConsumerConfig = new HashMap<>(); + consumerConfig.forEach( + (key, dataBytes) -> { + // Adding all 
allowed properties. + if (!KafkaIOUtils.DISALLOWED_CONSUMER_PROPERTIES.containsKey(key)) { + if (consumerConfig.get(key) == null) { + throw new IllegalArgumentException( + "Encoded value of the consumer config property " + key + " was null"); + } + updatedConsumerConfig.put(key, fromByteArray(consumerConfig.get(key))); + } + }); + transform = transform.withConsumerConfigUpdates(updatedConsumerConfig); + } + Collection topics = configRow.getArray("topics"); + if (topics != null) { + transform = transform.withTopics(new ArrayList<>(topics)); + } + Collection topicPartitionRows = configRow.getArray("topic_partitions"); + if (topicPartitionRows != null) { + Collection topicPartitions = + topicPartitionRows.stream() + .map( + row -> { + String topic = row.getString("topic"); + if (topic == null) { + throw new IllegalArgumentException("Expected the topic to be not null"); + } + Integer partition = row.getInt32("partition"); + if (partition == null) { + throw new IllegalArgumentException("Expected the partition to be not null"); + } + return new TopicPartition(topic, partition); + }) + .collect(Collectors.toList()); + transform = transform.withTopicPartitions(Lists.newArrayList(topicPartitions)); + } + String topicPattern = configRow.getString("topic_pattern"); + if (topicPattern != null) { + transform = transform.withTopicPattern(topicPattern); + } + + byte[] keyDeserializerProvider = configRow.getBytes("key_deserializer_provider"); + if (keyDeserializerProvider != null) { + + byte[] keyCoder = configRow.getBytes("key_coder"); + if (keyCoder != null) { + transform = + transform.withKeyDeserializerProviderAndCoder( + (DeserializerProvider) fromByteArray(keyDeserializerProvider), + (org.apache.beam.sdk.coders.Coder) fromByteArray(keyCoder)); + } else { + transform = + transform.withKeyDeserializer( + (DeserializerProvider) fromByteArray(keyDeserializerProvider)); + } + } + + byte[] valueDeserializerProvider = configRow.getBytes("value_deserializer_provider"); + if (valueDeserializerProvider != null) { + byte[] valueCoder = configRow.getBytes("value_coder"); + if (valueCoder != null) { + transform = + transform.withValueDeserializerProviderAndCoder( + (DeserializerProvider) fromByteArray(valueDeserializerProvider), + (org.apache.beam.sdk.coders.Coder) fromByteArray(valueCoder)); + } else { + transform = + transform.withValueDeserializer( + (DeserializerProvider) fromByteArray(valueDeserializerProvider)); + } + } + + byte[] consumerFactoryFn = configRow.getBytes("consumer_factory_fn"); + if (consumerFactoryFn != null) { + transform = + transform.withConsumerFactoryFn( + (SerializableFunction, Consumer>) + fromByteArray(consumerFactoryFn)); + } + byte[] watermarkFn = configRow.getBytes("watermark_fn"); + if (watermarkFn != null) { + transform = transform.withWatermarkFn2((SerializableFunction) fromByteArray(watermarkFn)); + } + Long maxNumRecords = configRow.getInt64("max_num_records"); + if (maxNumRecords != null) { + transform = transform.withMaxNumRecords(maxNumRecords); + } + Duration maxReadTime = configRow.getValue("max_read_time"); + if (maxReadTime != null) { + transform = + transform.withMaxReadTime(org.joda.time.Duration.millis(maxReadTime.toMillis())); + } + Instant startReadTime = configRow.getValue("start_read_time"); + if (startReadTime != null) { + transform = transform.withStartReadTime(startReadTime); + } + Instant stopReadTime = configRow.getValue("stop_read_time"); + if (stopReadTime != null) { + transform = transform.withStopReadTime(stopReadTime); + } + Boolean 
isCommitOffsetFinalizeEnabled = + configRow.getBoolean("is_commit_offset_finalize_enabled"); + if (isCommitOffsetFinalizeEnabled != null && isCommitOffsetFinalizeEnabled) { + transform = transform.commitOffsetsInFinalize(); + } + Boolean isDynamicRead = configRow.getBoolean("is_dynamic_read"); + if (isDynamicRead != null && isDynamicRead) { + Duration watchTopicPartitionDuration = configRow.getValue("watch_topic_partition_duration"); + if (watchTopicPartitionDuration == null) { + throw new IllegalArgumentException( + "Expected watchTopicPartitionDuration to be available when isDynamicRead is set to true"); + } + transform = + transform.withDynamicRead( + org.joda.time.Duration.millis(watchTopicPartitionDuration.toMillis())); + } + + byte[] timestampPolicyFactory = configRow.getBytes("timestamp_policy_factory"); + if (timestampPolicyFactory != null) { + transform = + transform.withTimestampPolicyFactory( + (TimestampPolicyFactory) fromByteArray(timestampPolicyFactory)); + } + Map offsetConsumerConfig = configRow.getMap("offset_consumer_config"); + if (offsetConsumerConfig != null) { + Map updatedOffsetConsumerConfig = new HashMap<>(); + offsetConsumerConfig.forEach( + (key, dataBytes) -> { + if (offsetConsumerConfig.get(key) == null) { + throw new IllegalArgumentException( + "Encoded value for the offset consumer config key " + key + " was null."); + } + updatedOffsetConsumerConfig.put(key, fromByteArray(offsetConsumerConfig.get(key))); + }); + transform = transform.withOffsetConsumerConfigOverrides(updatedOffsetConsumerConfig); + } + + byte[] checkStopReadinfFn = configRow.getBytes("check_stop_reading_fn"); + if (checkStopReadinfFn != null) { + transform = + transform.withCheckStopReadingFn( + (SerializableFunction) fromByteArray(checkStopReadinfFn)); + } + + return transform; + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class ReadRegistrar implements TransformPayloadTranslatorRegistrar { + + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map, ? extends TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap., TransformPayloadTranslator>builder() + .put(Read.AUTOVALUE_CLASS, new KafkaIOReadWithMetadataTranslator()) + .build(); + } + } + + static class KafkaIOWriteTranslator implements TransformPayloadTranslator> { + + static Schema schema = + Schema.builder() + .addStringField("bootstrap_servers") + .addNullableStringField("topic") + .addNullableByteArrayField("key_serializer") + .addNullableByteArrayField("value_serializer") + .addNullableByteArrayField("producer_factory_fn") + .addNullableByteArrayField("publish_timestamp_fn") + .addBooleanField("eos") + .addInt32Field("num_shards") + .addNullableStringField("sink_group_id") + .addNullableByteArrayField("consumer_factory_fn") + .addNullableMapField("producer_config", FieldType.STRING, FieldType.BYTES) + .build(); + + @Override + public String getUrn() { + return KAFKA_WRITE_TRANSFORM_URN_V2; + } + + @Override + public String getUrn(Write transform) { + return TransformPayloadTranslator.super.getUrn(transform); + } + + @Override + @SuppressWarnings({ + "type.argument", + }) + public @Nullable FunctionSpec translate( + AppliedPTransform> application, SdkComponents components) + throws IOException { + { + // Setting an empty payload since Kafka transform payload is not actually used by runners + // currently. + // This can be implemented if runners started actually using the Kafka transform payload. 
+ return FunctionSpec.newBuilder().setUrn(getUrn()).setPayload(ByteString.empty()).build(); + } + } + + @Override + public Row toConfigRow(Write transform) { + Map fieldValues = new HashMap<>(); + + WriteRecords writeRecordsTransform = transform.getWriteRecordsTransform(); + + if (!writeRecordsTransform + .getProducerConfig() + .containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)) { + throw new IllegalArgumentException( + "Expected the producer config to have 'ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG' set. Found: " + + writeRecordsTransform.getProducerConfig()); + } + fieldValues.put( + "bootstrap_servers", + writeRecordsTransform.getProducerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)); + if (writeRecordsTransform.getTopic() != null) { + fieldValues.put("topic", writeRecordsTransform.getTopic()); + } + if (writeRecordsTransform.getKeySerializer() != null) { + fieldValues.put("key_serializer", toByteArray(writeRecordsTransform.getKeySerializer())); + } + if (writeRecordsTransform.getValueSerializer() != null) { + fieldValues.put( + "value_serializer", toByteArray(writeRecordsTransform.getValueSerializer())); + } + if (writeRecordsTransform.getProducerFactoryFn() != null) { + fieldValues.put( + "producer_factory_fn", toByteArray(writeRecordsTransform.getProducerFactoryFn())); + } + if (writeRecordsTransform.getPublishTimestampFunction() != null) { + fieldValues.put( + "publish_timestamp_fn", + toByteArray(writeRecordsTransform.getPublishTimestampFunction())); + } + + fieldValues.put("eos", writeRecordsTransform.isEOS()); + fieldValues.put("num_shards", writeRecordsTransform.getNumShards()); + + if (writeRecordsTransform.getSinkGroupId() != null) { + fieldValues.put("sink_group_id", writeRecordsTransform.getSinkGroupId()); + } + if (writeRecordsTransform.getConsumerFactoryFn() != null) { + fieldValues.put( + "consumer_factory_fn", toByteArray(writeRecordsTransform.getConsumerFactoryFn())); + } + + if (writeRecordsTransform.getProducerConfig().size() > 0) { + Map producerConfigMap = new HashMap<>(); + writeRecordsTransform + .getProducerConfig() + .forEach( + (key, value) -> { + producerConfigMap.put((String) key, toByteArray(value)); + }); + fieldValues.put("producer_config", producerConfigMap); + } + + return Row.withSchema(schema).withFieldValues(fieldValues).build(); + } + + @Override + public Write fromConfigRow(Row configRow) { + Write transform = KafkaIO.write(); + + String bootstrapServers = configRow.getString("bootstrap_servers"); + if (bootstrapServers != null) { + transform = transform.withBootstrapServers(bootstrapServers); + } + String topic = configRow.getValue("topic"); + if (topic != null) { + transform = transform.withTopic(topic); + } + byte[] keySerializerBytes = configRow.getBytes("key_serializer"); + if (keySerializerBytes != null) { + transform = transform.withKeySerializer((Class) fromByteArray(keySerializerBytes)); + } + byte[] valueSerializerBytes = configRow.getBytes("value_serializer"); + if (valueSerializerBytes != null) { + transform = transform.withValueSerializer((Class) fromByteArray(valueSerializerBytes)); + } + byte[] producerFactoryFnBytes = configRow.getBytes("producer_factory_fn"); + if (producerFactoryFnBytes != null) { + transform = + transform.withProducerFactoryFn( + (SerializableFunction) fromByteArray(producerFactoryFnBytes)); + } + Boolean isEOS = configRow.getBoolean("eos"); + if (isEOS != null && isEOS) { + Integer numShards = configRow.getInt32("num_shards"); + String sinkGroupId = configRow.getString("sink_group_id"); + if (numShards 
== null) { + throw new IllegalArgumentException( + "Expected numShards to be provided when EOS is set to true"); + } + if (sinkGroupId == null) { + throw new IllegalArgumentException( + "Expected sinkGroupId to be provided when EOS is set to true"); + } + transform = transform.withEOS(numShards, sinkGroupId); + } + byte[] consumerFactoryFnBytes = configRow.getBytes("consumer_factory_fn"); + if (consumerFactoryFnBytes != null) { + transform = + transform.withConsumerFactoryFn( + (SerializableFunction) fromByteArray(consumerFactoryFnBytes)); + } + + Map producerConfig = configRow.getMap("producer_config"); + if (producerConfig != null && !producerConfig.isEmpty()) { + Map updatedProducerConfig = new HashMap<>(); + producerConfig.forEach( + (key, dataBytes) -> { + updatedProducerConfig.put(key, fromByteArray((byte[]) dataBytes)); + }); + transform = transform.withProducerConfigUpdates(updatedProducerConfig); + } + + return transform; + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class WriteRegistrar implements TransformPayloadTranslatorRegistrar { + + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map, ? extends TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap., TransformPayloadTranslator>builder() + .put(Write.AUTOVALUE_CLASS, new KafkaIOWriteTranslator()) + .build(); + } + } +} diff --git a/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/package-info.java b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/package-info.java new file mode 100644 index 000000000000..d19a4e630d73 --- /dev/null +++ b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** A library to support upgrading Kafka transforms without upgrading the pipeline. */ +package org.apache.beam.sdk.io.kafka.upgrade; diff --git a/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java b/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java new file mode 100644 index 000000000000..d99bee0ad209 --- /dev/null +++ b/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka.upgrade; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.beam.runners.core.construction.TransformUpgrader; +import org.apache.beam.sdk.io.kafka.KafkaIO; +import org.apache.beam.sdk.io.kafka.KafkaIO.Read; +import org.apache.beam.sdk.io.kafka.KafkaIO.Write; +import org.apache.beam.sdk.io.kafka.KafkaIO.WriteRecords; +import org.apache.beam.sdk.io.kafka.upgrade.KafkaIOTranslation.KafkaIOReadWithMetadataTranslator; +import org.apache.beam.sdk.io.kafka.upgrade.KafkaIOTranslation.KafkaIOWriteTranslator; +import org.apache.beam.sdk.values.Row; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.TopicPartition; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for KafkaIOTranslation. */ +@RunWith(JUnit4.class) +public class KafkaIOTranslationTest { + + // A mapping from Read transform builder methods to the corresponding schema fields in + // KafkaIOTranslation. + static final Map READ_TRANSFORM_SCHEMA_MAPPING = new HashMap<>(); + + static { + READ_TRANSFORM_SCHEMA_MAPPING.put("getConsumerConfig", "consumer_config"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getTopics", "topics"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getTopicPartitions", "topic_partitions"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getTopicPattern", "topic_pattern"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getKeyCoder", "key_coder"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getValueCoder", "value_coder"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getConsumerFactoryFn", "consumer_factory_fn"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getWatermarkFn", "watermark_fn"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getMaxNumRecords", "max_num_records"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getMaxReadTime", "max_read_time"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getStartReadTime", "start_read_time"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getStopReadTime", "stop_read_time"); + READ_TRANSFORM_SCHEMA_MAPPING.put( + "isCommitOffsetsInFinalizeEnabled", "is_commit_offset_finalize_enabled"); + READ_TRANSFORM_SCHEMA_MAPPING.put("isDynamicRead", "is_dynamic_read"); + READ_TRANSFORM_SCHEMA_MAPPING.put( + "getWatchTopicPartitionDuration", "watch_topic_partition_duration"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getTimestampPolicyFactory", "timestamp_policy_factory"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getOffsetConsumerConfig", "offset_consumer_config"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getKeyDeserializerProvider", "key_deserializer_provider"); + READ_TRANSFORM_SCHEMA_MAPPING.put( + "getValueDeserializerProvider", "value_deserializer_provider"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getCheckStopReadingFn", "check_stop_reading_fn"); + } + + // A mapping from Write transform builder methods to the corresponding schema fields in + // KafkaIOTranslation. 
+ static final Map WRITE_TRANSFORM_SCHEMA_MAPPING = new HashMap<>(); + + static { + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getTopic", "topic"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getProducerConfig", "producer_config"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getProducerFactoryFn", "producer_factory_fn"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getKeySerializer", "key_serializer"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getValueSerializer", "value_serializer"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getPublishTimestampFunction", "publish_timestamp_fn"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("isEOS", "eos"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getSinkGroupId", "sink_group_id"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getNumShards", "num_shards"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getConsumerFactoryFn", "consumer_factory_fn"); + } + + @Test + public void testReCreateReadTransformFromRow() throws Exception { + Map consumerConfig = new HashMap<>(); + consumerConfig.put("dummyconfig", "dummyvalue"); + + Read readTransform = + KafkaIO.read() + .withBootstrapServers("dummykafkaserver") + .withTopicPartitions( + IntStream.range(0, 1) + .mapToObj(i -> new TopicPartition("dummytopic", i)) + .collect(Collectors.toList())) + .withConsumerConfigUpdates(consumerConfig); + KafkaIOTranslation.KafkaIOReadWithMetadataTranslator translator = + new KafkaIOReadWithMetadataTranslator(); + Row row = translator.toConfigRow(readTransform); + + Read readTransformFromRow = + (Read) translator.fromConfigRow(row); + assertNotNull( + readTransformFromRow.getConsumerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)); + assertEquals( + "dummykafkaserver", + readTransformFromRow.getConsumerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)); + assertEquals(1, readTransformFromRow.getTopicPartitions().size()); + assertEquals("dummytopic", readTransformFromRow.getTopicPartitions().get(0).topic()); + assertEquals(0, readTransformFromRow.getTopicPartitions().get(0).partition()); + } + + @Test + public void testReadTransformRowIncludesAllFields() throws Exception { + List getMethodNames = + Arrays.stream(Read.class.getDeclaredMethods()) + .map( + method -> { + return method.getName(); + }) + .filter(methodName -> methodName.startsWith("get")) + .collect(Collectors.toList()); + + // Just to make sure that this does not pass trivially. + assertTrue(getMethodNames.size() > 0); + + for (String getMethodName : getMethodNames) { + assertTrue( + "Method " + + getMethodName + + " will not be tracked when upgrading the 'KafkaIO.Read' transform. Please update " + + "'KafkaIOTranslation.KafkaIOReadWithMetadataTranslator' to track the new method " + + "and update this test.", + READ_TRANSFORM_SCHEMA_MAPPING.keySet().contains(getMethodName)); + } + + // Confirming that all fields mentioned in `readTransformMethodNameToSchemaFieldMapping` are + // actually available in the schema. 
+ READ_TRANSFORM_SCHEMA_MAPPING.values().stream() + .forEach( + fieldName -> { + assertTrue( + "Field name " + + fieldName + + " was not found in the read transform schema defined in " + + "KafkaIOReadWithMetadataTranslator.", + KafkaIOReadWithMetadataTranslator.schema.getFieldNames().contains(fieldName)); + }); + } + + @Test + public void testReCreateWriteTransformFromRow() throws Exception { + Map producerConfig = new HashMap<>(); + producerConfig.put("dummyconfig", "dummyvalue"); + Write writeTransform = + KafkaIO.write() + .withBootstrapServers("dummybootstrapserver") + .withTopic("dummytopic") + .withProducerConfigUpdates(producerConfig); + KafkaIOTranslation.KafkaIOWriteTranslator translator = + new KafkaIOTranslation.KafkaIOWriteTranslator(); + Row row = translator.toConfigRow(writeTransform); + + Write writeTransformFromRow = + (Write) translator.fromConfigRow(row); + WriteRecords writeRecordsTransform = + writeTransformFromRow.getWriteRecordsTransform(); + assertNotNull( + writeRecordsTransform.getProducerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)); + assertEquals( + "dummybootstrapserver", + writeRecordsTransform.getProducerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)); + assertEquals("dummytopic", writeRecordsTransform.getTopic()); + Map producerConfigFromRow = writeRecordsTransform.getProducerConfig(); + assertTrue(producerConfigFromRow.containsKey("dummyconfig")); + assertEquals("dummyvalue", producerConfigFromRow.get("dummyconfig")); + } + + @Test + public void testWriteTransformRowIncludesAllFields() throws Exception { + // Write transform delegates property handling to the WriteRecords class. So we inspect the + // WriteRecords class here. + List getMethodNames = + Arrays.stream(WriteRecords.class.getDeclaredMethods()) + .map( + method -> { + return method.getName(); + }) + .filter(methodName -> methodName.startsWith("get")) + .collect(Collectors.toList()); + + // Just to make sure that this does not pass trivially. + assertTrue(getMethodNames.size() > 0); + + for (String getMethodName : getMethodNames) { + assertTrue( + "Method " + + getMethodName + + " will not be tracked when upgrading the 'KafkaIO.Write' transform. Please update " + + "'KafkaIOTranslation.KafkaIOWriteTranslator' to track the new method and update " + + "this test.", + WRITE_TRANSFORM_SCHEMA_MAPPING.keySet().contains(getMethodName)); + } + + // Confirming that all fields mentioned in `writeTransformMethodNameToSchemaFieldMapping` are + // actually available in the schema. 
+ WRITE_TRANSFORM_SCHEMA_MAPPING.values().stream() + .forEach( + fieldName -> { + assertTrue( + "Field name " + + fieldName + + " was not found in the write transform schema defined in " + + "KafkaIOWriteWithMetadataTranslator.", + KafkaIOWriteTranslator.schema.getFieldNames().contains(fieldName)); + }); + } + + @Test + public void testReadTransformURNDiscovery() { + Read readTransform = + KafkaIO.read() + .withBootstrapServers("dummykafkaserver") + .withTopicPartitions( + IntStream.range(0, 1) + .mapToObj(i -> new TopicPartition("dummytopic", i)) + .collect(Collectors.toList())); + + assertEquals( + KafkaIOTranslation.KAFKA_READ_WITH_METADATA_TRANSFORM_URN_V2, + TransformUpgrader.findUpgradeURN(readTransform)); + } + + @Test + public void testWriteTransformURNDiscovery() { + Write writeTransform = + KafkaIO.write() + .withBootstrapServers("dummybootstrapserver") + .withTopic("dummytopic"); + + assertEquals( + KafkaIOTranslation.KAFKA_WRITE_TRANSFORM_URN_V2, + TransformUpgrader.findUpgradeURN(writeTransform)); + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 85ef793b8e12..03e3edf0b2de 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -247,6 +247,7 @@ include(":sdks:java:io:jdbc") include(":sdks:java:io:jms") include(":sdks:java:io:json") include(":sdks:java:io:kafka") +include(":sdks:java:io:kafka:upgrade") include(":sdks:java:io:kinesis") include(":sdks:java:io:kinesis:expansion-service") include(":sdks:java:io:kudu")
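Illustrative usage of the new lookup utility (not part of the diff above): a minimal sketch showing how a pipeline author could discover the upgrade URN for a KafkaIO transform through TransformUpgrader.findUpgradeURN. It assumes the new :sdks:java:io:kafka:upgrade artifact (which provides the AutoService-registered KafkaIOTranslation registrars) is on the classpath; the class name and bootstrap server address below are placeholders.

import org.apache.beam.runners.core.construction.TransformUpgrader;
import org.apache.beam.sdk.io.kafka.KafkaIO;

public class KafkaUpgradeUrnExample {
  public static void main(String[] args) {
    // Build a KafkaIO.Read; its concrete class is the AutoValue subclass that
    // KafkaIOTranslation.ReadRegistrar maps to KafkaIOReadWithMetadataTranslator.
    KafkaIO.Read<byte[], byte[]> read =
        KafkaIO.<byte[], byte[]>read().withBootstrapServers("localhost:9092");

    // findUpgradeURN scans TransformPayloadTranslatorRegistrar implementations via
    // ServiceLoader. With the Kafka upgrade registrar available, this returns
    // "beam:transform:org.apache.beam:kafka_read_with_metadata:v2"; otherwise null.
    String urn = TransformUpgrader.findUpgradeURN(read);
    System.out.println("Upgrade URN: " + urn);
  }
}

The returned URN identifies the transform to the Beam Transform Service so that it can be upgraded to a newer Beam version without upgrading the rest of the pipeline, as described in the findUpgradeURN Javadoc added in this change.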