Commit
feat(pgbulk): add retry with filtered df for pg output write (#10)
1 parent 9ce2a5f · commit 2ee9176
Showing 5 changed files with 173 additions and 40 deletions.
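
The pg output write change itself (the retry logic in the other four files) is not shown in this excerpt, only its test. Going by the commit title, the idea is: if a bulk write fails on duplicate primary keys, retry with the DataFrame filtered down to the rows not already present in the target table. A minimal sketch of that pattern, assuming hypothetical bulkCopy and readExistingKeys helpers (the real PGTool internals may differ):

import org.apache.spark.sql.DataFrame

// Hedged sketch of a "retry with filtered df" write. bulkCopy and
// readExistingKeys are assumed stand-ins, not the actual PGTool API.
def outputBulkWithRetry(
    table: String,
    df: DataFrame,
    primaryKeys: Seq[String],
    bulkCopy: (String, DataFrame) => Unit,
    readExistingKeys: (String, Seq[String]) => DataFrame
): Unit =
  try bulkCopy(table, df) // first attempt: bulk-load the whole DataFrame
  catch {
    case _: java.sql.SQLException =>
      // Keep only rows whose primary keys are absent from the table,
      // then retry the bulk load once with the filtered DataFrame.
      val existing = readExistingKeys(table, primaryKeys)
      val remaining = df.join(existing, primaryKeys, "left_anti")
      bulkCopy(table, remaining)
  }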
src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala (68 additions, 0 deletions)
package fr.aphp.id.eds.requester.cohort.pg

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuiteLike
import org.scalatest.matchers.should.Matchers
import org.testcontainers.containers.PostgreSQLContainer

import java.nio.file.{Files, Path}

class PGToolTest extends AnyFunSuiteLike with Matchers with BeforeAndAfterAll {
  val sparkSession: SparkSession = SparkSession
    .builder()
    .master("local[*]")
    .getOrCreate()
  private var tempDir: Path = _
  private val postgresContainer = new PostgreSQLContainer("postgres:15.3")

  override def beforeAll(): Unit = {
    super.beforeAll()
    tempDir = Files.createTempDirectory("test-temp-dir")
    // Credentials must be set before start(): configuring an already
    // running container has no effect.
    postgresContainer.withUsername("test")
    postgresContainer.withPassword("test")
    postgresContainer.start()
    // Write a .pgpass entry (host:port:database:user:password) so PGTool
    // can authenticate its bulk connections.
    val pgPassFile = tempDir.resolve(".pgpass")
    Files.write(
      pgPassFile,
      s"${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}:*:${postgresContainer.getUsername}:${postgresContainer.getPassword}".getBytes
    )
  }

  override def afterAll(): Unit = {
    super.afterAll()
    FileUtils.deleteDirectory(tempDir.toFile)
    postgresContainer.stop()
  }

  test("testOutputBulk") {
    val pgUrl =
      s"jdbc:postgresql://${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}/${postgresContainer.getDatabaseName}?user=${postgresContainer.getUsername}&currentSchema=public"
    val pgTool = PGTool(
      sparkSession,
      pgUrl,
      tempDir.toString,
      pgPassFile = new org.apache.hadoop.fs.Path(tempDir.resolve(".pgpass").toString)
    )
    val createTableQuery = """
      CREATE TABLE test_table (
        id INT PRIMARY KEY,
        value TEXT,
        id_2 INT
      )
    """
    pgTool.sqlExec(createTableQuery)

    // Seed two rows whose primary keys overlap with the bulk load below.
    val insertDataQuery = """
      INSERT INTO test_table (id, value, id_2) VALUES
      (1, '1', 1),
      (2, '2', 2)
    """
    pgTool.sqlExec(insertDataQuery)
    val baseContent = pgTool.sqlExecWithResult("select * from test_table")
    baseContent.collect().map(_.getInt(0)) should contain theSameElementsAs Array(1, 2)

    // Generate a 100-row dataframe with columns id, value and id_2 and
    // bulk-write it to the database.
    val data = sparkSession
      .range(100)
      .toDF("id")
      .withColumn("value", col("id").cast("string"))
      .withColumn("id_2", col("id"))
    pgTool.outputBulk("test_table", data, primaryKeys = Seq("id", "id_2"))
    val updatedContent = pgTool.sqlExecWithResult("select * from test_table")
    updatedContent.collect().map(_.getInt(0)) should contain theSameElementsAs (0 until 100)
  }
}
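
The test seeds test_table with rows (1, '1', 1) and (2, '2', 2) before bulk-writing ids 0 through 99 with primaryKeys = Seq("id", "id_2"), so the final assertion only passes if the write path tolerates the pre-existing keys instead of aborting on the conflict, which is the retry-with-filtered-DataFrame behavior the commit title describes.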