From 95a08451c58cf95a1c844e17e08832532a796bd7 Mon Sep 17 00:00:00 2001 From: Ravi Magham Date: Tue, 3 Dec 2024 14:59:05 -0800 Subject: [PATCH] jdbcio write batch config changes (#33205) * jdbcio write batch config changes * fixup: lint fixes * fixup: format fixes * fixup: support batch size for yaml based pipeline * fixup: review comments --------- Co-authored-by: Ravi Magham Co-authored-by: PoojaS2010 --- .../beam/sdk/io/jdbc/JdbcSchemaIOProvider.java | 5 +++++ .../jdbc/JdbcWriteSchemaTransformProvider.java | 11 +++++++++++ .../org/apache/beam/sdk/io/jdbc/JdbcIOTest.java | 16 ++++++++++++++++ .../sdk/io/jdbc/JdbcSchemaIOProviderTest.java | 2 ++ .../JdbcWriteSchemaTransformProviderTest.java | 1 + sdks/python/apache_beam/io/jdbc.py | 9 ++++++++- sdks/python/apache_beam/yaml/standard_io.yaml | 1 + 7 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java index 30012465eb9e..11034aee1cdf 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java @@ -68,6 +68,7 @@ public Schema configurationSchema() { .addNullableField("disableAutoCommit", FieldType.BOOLEAN) .addNullableField("outputParallelization", FieldType.BOOLEAN) .addNullableField("autosharding", FieldType.BOOLEAN) + .addNullableField("writeBatchSize", FieldType.INT64) // Partitioning support. 
If you specify a partition column we will use that instead of // readQuery .addNullableField("partitionColumn", FieldType.STRING) @@ -194,6 +195,10 @@ public PDone expand(PCollection input) { if (autosharding != null && autosharding) { writeRows = writeRows.withAutoSharding(); } + @Nullable Long writeBatchSize = config.getInt64("writeBatchSize"); + if (writeBatchSize != null) { + writeRows = writeRows.withBatchSize(writeBatchSize); + } return input.apply(writeRows); } }; diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java index a409b604b11f..1f970ba0624f 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java @@ -141,6 +141,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { if (autosharding != null && autosharding) { writeRows = writeRows.withAutoSharding(); } + + Long writeBatchSize = config.getBatchSize(); + if (writeBatchSize != null) { + writeRows = writeRows.withBatchSize(writeBatchSize); + } + PCollection postWrite = input .get("input") @@ -205,6 +211,9 @@ public abstract static class JdbcWriteSchemaTransformConfiguration implements Se @Nullable public abstract String getDriverJars(); + @Nullable + public abstract Long getBatchSize(); + public void validate() throws IllegalArgumentException { if (Strings.isNullOrEmpty(getJdbcUrl())) { throw new IllegalArgumentException("JDBC URL cannot be blank"); @@ -268,6 +277,8 @@ public abstract Builder setConnectionInitSql( public abstract Builder setDriverJars(String value); + public abstract Builder setBatchSize(Long value); + public abstract JdbcWriteSchemaTransformConfiguration build(); } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java 
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index a04f8c4e762f..8725ef4b3f78 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -549,6 +549,22 @@ public void testWrite() throws Exception { } } + @Test + public void testWriteWithBatchSize() throws Exception { + String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); + DatabaseTestHelper.createTable(DATA_SOURCE, tableName); + try { + ArrayList> data = getDataToWrite(EXPECTED_ROW_COUNT); + pipeline.apply(Create.of(data)).apply(getJdbcWrite(tableName).withBatchSize(10L)); + + pipeline.run(); + + assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT); + } finally { + DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName); + } + } + @Test public void testWriteWithAutosharding() throws Exception { String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java index ed380d813625..193a1f0c3477 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java @@ -133,6 +133,7 @@ public void testAbleToReadDataSourceConfiguration() { .withFieldValue("connectionInitSqls", new ArrayList<>(Collections.singleton("initSql"))) .withFieldValue("maxConnections", (short) 3) .withFieldValue("driverJars", "test.jar") + .withFieldValue("writeBatchSize", 10L) .build(); JdbcSchemaIOProvider.JdbcSchemaIO schemaIO = provider.from(READ_TABLE_NAME, config, Schema.builder().build()); @@ -148,6 +149,7 @@ public void testAbleToReadDataSourceConfiguration() { Objects.requireNonNull(dataSourceConf.getConnectionInitSqls()).get()); assertEquals(3, (int) 
dataSourceConf.getMaxConnections().get()); assertEquals("test.jar", Objects.requireNonNull(dataSourceConf.getDriverJars()).get()); + assertEquals(10L, schemaIO.config.getInt64("writeBatchSize").longValue()); } /** Create test data that is consistent with that generated by TestRow. */ diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java index f66a143323e5..d6be4d9f89c8 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java @@ -175,6 +175,7 @@ public void testWriteToTable() throws SQLException { .setDriverClassName(DATA_SOURCE_CONFIGURATION.getDriverClassName().get()) .setJdbcUrl(DATA_SOURCE_CONFIGURATION.getUrl().get()) .setLocation(writeTableName) + .setBatchSize(1L) .build())); pipeline.run(); DatabaseTestHelper.assertRowCount(DATA_SOURCE, writeTableName, 2); diff --git a/sdks/python/apache_beam/io/jdbc.py b/sdks/python/apache_beam/io/jdbc.py index d4ece0c7bc29..11570680a2f3 100644 --- a/sdks/python/apache_beam/io/jdbc.py +++ b/sdks/python/apache_beam/io/jdbc.py @@ -131,7 +131,8 @@ def default_io_expansion_service(classpath=None): ('partition_column', typing.Optional[str]), ('partitions', typing.Optional[np.int16]), ('max_connections', typing.Optional[np.int16]), - ('driver_jars', typing.Optional[str])], + ('driver_jars', typing.Optional[str]), + ('write_batch_size', typing.Optional[np.int64])], ) DEFAULT_JDBC_CLASSPATH = ['org.postgresql:postgresql:42.2.16'] @@ -187,6 +188,7 @@ def __init__( driver_jars=None, expansion_service=None, classpath=None, + write_batch_size=None, ): """ Initializes a write operation to Jdbc. @@ -218,6 +220,9 @@ def __init__( package (e.g. "org.postgresql:postgresql:42.3.1"). 
By default, this argument includes a Postgres SQL JDBC driver. + :param write_batch_size: sets the maximum number of SQL statements + in a single write batch. If not set, the Java-side default + (``JdbcIO.DEFAULT_BATCH_SIZE``) is used. """ classpath = classpath or DEFAULT_JDBC_CLASSPATH super().__init__( @@ -235,6 +240,7 @@ def __init__( connection_properties=connection_properties, connection_init_sqls=connection_init_sqls, write_statement=statement, + write_batch_size=write_batch_size, read_query=None, fetch_size=None, disable_autocommit=None, @@ -352,6 +358,7 @@ def __init__( connection_properties=connection_properties, connection_init_sqls=connection_init_sqls, write_statement=None, + write_batch_size=None, read_query=query, fetch_size=fetch_size, disable_autocommit=disable_autocommit, diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml index 269c14e17baa..305e6877ad90 100644 --- a/sdks/python/apache_beam/yaml/standard_io.yaml +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -225,6 +225,7 @@ driver_jars: 'driver_jars' connection_properties: 'connection_properties' connection_init_sql: 'connection_init_sql' + batch_size: 'batch_size' 'ReadFromMySql': 'ReadFromJdbc' 'WriteToMySql': 'WriteToJdbc' 'ReadFromPostgres': 'ReadFromJdbc'