Skip to content

Commit

Permalink
jdbcio write batch config changes (#33205)
Browse files Browse the repository at this point in the history
* jdbcio write batch config changes

* fixup: lint fixes

* fixup: format fixes

* fixup: support batch size for yaml based pipeline

* fixup: review comments

---------

Co-authored-by: Ravi Magham <[email protected]>
Co-authored-by: PoojaS2010 <[email protected]>
  • Loading branch information
3 people authored Dec 3, 2024
1 parent 2fe725d commit 95a0845
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public Schema configurationSchema() {
.addNullableField("disableAutoCommit", FieldType.BOOLEAN)
.addNullableField("outputParallelization", FieldType.BOOLEAN)
.addNullableField("autosharding", FieldType.BOOLEAN)
.addNullableField("writeBatchSize", FieldType.INT64)
// Partitioning support. If you specify a partition column we will use that instead of
// readQuery
.addNullableField("partitionColumn", FieldType.STRING)
Expand Down Expand Up @@ -194,6 +195,10 @@ public PDone expand(PCollection<Row> input) {
if (autosharding != null && autosharding) {
writeRows = writeRows.withAutoSharding();
}
@Nullable Long writeBatchSize = config.getInt64("writeBatchSize");
if (writeBatchSize != null) {
writeRows = writeRows.withBatchSize(writeBatchSize);
}
return input.apply(writeRows);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
if (autosharding != null && autosharding) {
writeRows = writeRows.withAutoSharding();
}

Long writeBatchSize = config.getBatchSize();
if (writeBatchSize != null) {
writeRows = writeRows.withBatchSize(writeBatchSize);
}

PCollection<Row> postWrite =
input
.get("input")
Expand Down Expand Up @@ -205,6 +211,9 @@ public abstract static class JdbcWriteSchemaTransformConfiguration implements Se
@Nullable
public abstract String getDriverJars();

@Nullable
public abstract Long getBatchSize();

public void validate() throws IllegalArgumentException {
if (Strings.isNullOrEmpty(getJdbcUrl())) {
throw new IllegalArgumentException("JDBC URL cannot be blank");
Expand Down Expand Up @@ -268,6 +277,8 @@ public abstract Builder setConnectionInitSql(

public abstract Builder setDriverJars(String value);

public abstract Builder setBatchSize(Long value);

public abstract JdbcWriteSchemaTransformConfiguration build();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,22 @@ public void testWrite() throws Exception {
}
}

@Test
public void testWriteWithBatchSize() throws Exception {
String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
DatabaseTestHelper.createTable(DATA_SOURCE, tableName);
try {
ArrayList<KV<Integer, String>> data = getDataToWrite(EXPECTED_ROW_COUNT);
pipeline.apply(Create.of(data)).apply(getJdbcWrite(tableName).withBatchSize(10L));

pipeline.run();

assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT);
} finally {
DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName);
}
}

@Test
public void testWriteWithAutosharding() throws Exception {
String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public void testAbleToReadDataSourceConfiguration() {
.withFieldValue("connectionInitSqls", new ArrayList<>(Collections.singleton("initSql")))
.withFieldValue("maxConnections", (short) 3)
.withFieldValue("driverJars", "test.jar")
.withFieldValue("writeBatchSize", 10L)
.build();
JdbcSchemaIOProvider.JdbcSchemaIO schemaIO =
provider.from(READ_TABLE_NAME, config, Schema.builder().build());
Expand All @@ -148,6 +149,7 @@ public void testAbleToReadDataSourceConfiguration() {
Objects.requireNonNull(dataSourceConf.getConnectionInitSqls()).get());
assertEquals(3, (int) dataSourceConf.getMaxConnections().get());
assertEquals("test.jar", Objects.requireNonNull(dataSourceConf.getDriverJars()).get());
assertEquals(10L, schemaIO.config.getInt64("writeBatchSize").longValue());
}

/** Create test data that is consistent with that generated by TestRow. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ public void testWriteToTable() throws SQLException {
.setDriverClassName(DATA_SOURCE_CONFIGURATION.getDriverClassName().get())
.setJdbcUrl(DATA_SOURCE_CONFIGURATION.getUrl().get())
.setLocation(writeTableName)
.setBatchSize(1L)
.build()));
pipeline.run();
DatabaseTestHelper.assertRowCount(DATA_SOURCE, writeTableName, 2);
Expand Down
9 changes: 8 additions & 1 deletion sdks/python/apache_beam/io/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def default_io_expansion_service(classpath=None):
('partition_column', typing.Optional[str]),
('partitions', typing.Optional[np.int16]),
('max_connections', typing.Optional[np.int16]),
('driver_jars', typing.Optional[str])],
('driver_jars', typing.Optional[str]),
('write_batch_size', typing.Optional[np.int64])],
)

DEFAULT_JDBC_CLASSPATH = ['org.postgresql:postgresql:42.2.16']
Expand Down Expand Up @@ -187,6 +188,7 @@ def __init__(
driver_jars=None,
expansion_service=None,
classpath=None,
write_batch_size=None,
):
"""
Initializes a write operation to Jdbc.
Expand Down Expand Up @@ -218,6 +220,9 @@ def __init__(
package (e.g. "org.postgresql:postgresql:42.3.1").
By default, this argument includes a Postgres SQL JDBC
driver.
:param write_batch_size: sets the maximum size in number of SQL statement
for the batch.
default is {@link JdbcIO.DEFAULT_BATCH_SIZE}
"""
classpath = classpath or DEFAULT_JDBC_CLASSPATH
super().__init__(
Expand All @@ -235,6 +240,7 @@ def __init__(
connection_properties=connection_properties,
connection_init_sqls=connection_init_sqls,
write_statement=statement,
write_batch_size=write_batch_size,
read_query=None,
fetch_size=None,
disable_autocommit=None,
Expand Down Expand Up @@ -352,6 +358,7 @@ def __init__(
connection_properties=connection_properties,
connection_init_sqls=connection_init_sqls,
write_statement=None,
write_batch_size=None,
read_query=query,
fetch_size=fetch_size,
disable_autocommit=disable_autocommit,
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/yaml/standard_io.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@
driver_jars: 'driver_jars'
connection_properties: 'connection_properties'
connection_init_sql: 'connection_init_sql'
batch_size: 'batch_size'
'ReadFromMySql': 'ReadFromJdbc'
'WriteToMySql': 'WriteToJdbc'
'ReadFromPostgres': 'ReadFromJdbc'
Expand Down

0 comments on commit 95a0845

Please sign in to comment.