feat(pgbulk): add retry with filtered df for pg output write (#10)
pl-buiquang authored Jan 6, 2025
1 parent 9ce2a5f commit 2ee9176
Showing 5 changed files with 173 additions and 40 deletions.
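
At the API level, the low-level COPY step now reports success per partition, and when a bulk write fails, outputBulk retries once with the input dataframe filtered down to rows whose primary-key values are not already in the target table. A minimal caller-side sketch (the JDBC URL, table name, and temp path are placeholders; the API names follow the diff below):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import fr.aphp.id.eds.requester.cohort.pg.PGTool

val spark = SparkSession.builder().master("local[*]").getOrCreate()
// placeholder connection details; the password is resolved from ~/.pgpass by default
val url = "jdbc:postgresql://localhost:5432/mydb?user=myuser&currentSchema=public"
val pg = PGTool(spark, url, "/tmp/spark-postgres")

val data = spark.range(100).toDF("id").withColumn("value", col("id").cast("string"))
// if the bulk COPY fails (e.g. on duplicate keys), outputBulk now retries with only
// the rows whose primary keys are not already present in the target table
pg.outputBulk("my_table", data, numPartitions = Some(4), primaryKeys = Seq("id"))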
11 changes: 10 additions & 1 deletion pom.xml
@@ -31,6 +31,7 @@
<hadoop.version>2.6.5</hadoop.version>
<spark-solr.version>4.0.4</spark-solr.version>
<wiremock.version>3.3.1</wiremock.version>
<postgrestest.version>1.20.4</postgrestest.version>
<!-- Sonar -->
<sonar.qualitygate.wait>true</sonar.qualitygate.wait>

@@ -296,6 +297,12 @@
<version>${wiremock.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>postgresql</artifactId>
<version>${postgrestest.version}</version>
<scope>test</scope>
</dependency>



@@ -446,7 +453,9 @@
</execution>
</executions>
<configuration>
<argLine>--add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED</argLine>
<argLine>
--add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED
</argLine>
</configuration>
</plugin>

@@ -50,8 +50,8 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
}

/**
* This loads both a cohort and its definition into postgres and solr
*/
override def updateCohort(cohortId: Long,
cohort: DataFrame,
sourcePopulation: SourcePopulation,
@@ -68,9 +68,9 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
.withColumnRenamed(ResultColumn.SUBJECT, "_itemreferenceid")
.withColumn("item__reference", concat(lit(s"${resourceType}/"), col("_itemreferenceid")))
.select(F.col("_itemreferenceid"),
F.col("item__reference"),
F.col("_provider"),
F.col("_listid"))

uploadCohortTableToPG(dataframe)

@@ -172,7 +172,10 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
).toSet == df.columns.toSet,
"cohort dataframe shall have _listid, _itemreferenceid, _provider and item__reference"
)
pg.outputBulk(cohort_item_table_rw, dfAddHash(df), Some(4))
pg.outputBulk(cohort_item_table_rw,
dfAddHash(df),
Some(4),
primaryKeys = Seq("_listid", "_itemreferenceid", "_provider"))
}

/**
118 changes: 85 additions & 33 deletions src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala
@@ -14,6 +14,7 @@ import java.io._
import java.sql._
import java.util.Properties
import java.util.UUID.randomUUID
import scala.util.{Failure, Success, Try}

sealed trait BulkLoadMode

@@ -35,8 +36,8 @@

private var password: String = ""

def setPassword(pwd: String = ""): PGTool = {
password = PGTool.passwordFromConn(url, pwd)
def setPassword(pwd: String): PGTool = {
password = pwd
this
}

@@ -73,7 +74,8 @@
table: String,
df: Dataset[Row],
numPartitions: Option[Int] = None,
reindex: Boolean = false
reindex: Boolean = false,
primaryKeys: Seq[String] = Seq.empty
): PGTool = {
PGTool.outputBulk(
spark,
@@ -84,7 +86,8 @@
numPartitions.getOrElse(8),
password,
reindex,
bulkLoadBufferSize
bulkLoadBufferSize,
primaryKeys = primaryKeys
)
this
}
@@ -106,45 +109,44 @@ object PGTool extends java.io.Serializable with LazyLogging {
url: String,
tmpPath: String,
bulkLoadMode: BulkLoadMode = defaultBulkLoadStrategy,
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
pgPassFile: Path = new Path(scala.sys.env("HOME"), ".pgpass")
): PGTool = {
new PGTool(
spark,
url,
tmpPath + "/spark-postgres-" + randomUUID.toString,
bulkLoadMode,
bulkLoadBufferSize
).setPassword()
).setPassword(passwordFromConn(url, pgPassFile))
}

def connOpen(url: String, password: String = ""): Connection = {
def connOpen(url: String, password: String): Connection = {
val prop = new Properties()
prop.put("driver", "org.postgresql.Driver")
prop.put("password", passwordFromConn(url, password))
prop.put("password", password)
DriverManager.getConnection(url, prop)
}

def passwordFromConn(url: String, password: String): String = {
if (password.nonEmpty) {
return password
}
def passwordFromConn(url: String, pgPassFile: Path): String = {
val pattern = "jdbc:postgresql://(.*):(\\d+)/(\\w+)[?]user=(\\w+).*".r
val pattern(host, port, database, username) = url
dbPassword(host, port, database, username)
dbPassword(host, port, database, username, pgPassFile)
}

private def dbPassword(
hostname: String,
port: String,
database: String,
username: String
username: String,
pgPassFile: Path
): String = {
// Usage: val thatPassWord = dbPassword(hostname, port, database, username, pgPassFile)
// .pgpass file format, hostname:port:database:username:password
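// example .pgpass entry (placeholder values): localhost:5432:mydb:myuser:secret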

val fs = FileSystem.get(new java.net.URI("file:///"), new Configuration)
val reader = new BufferedReader(
new InputStreamReader(fs.open(new Path(scala.sys.env("HOME"), ".pgpass")))
new InputStreamReader(fs.open(pgPassFile))
)
val content = Iterator
.continually(reader.readLine())
@@ -185,7 +187,9 @@
numPartitions: Int = 8,
password: String = "",
reindex: Boolean = false,
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
withRetry: Boolean = true,
primaryKeys: Seq[String] = Seq.empty
): Unit = {
logger.debug("using CSV strategy")
try {
@@ -208,7 +212,7 @@
.mode(org.apache.spark.sql.SaveMode.Overwrite)
.save(path)

outputBulkCsvLow(
val success = outputBulkCsvLow(
spark,
url,
table,
@@ -220,6 +224,45 @@
password,
bulkLoadBufferSize
)
if (!success) {
if (!withRetry) {
throw new Exception("Bulk load failed")
} else {
logger.warn(
"Bulk load failed, retrying with filtering existing items"
)
// try again with filtering the original dataframe with existing items
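// rows already present in the table (matched on the primary-key columns, or on every column when no primary keys are given) are dropped before the retry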
val selectedColumns = if (primaryKeys.isEmpty) "*" else primaryKeys.map(sanP).mkString(",")
val existingItems = sqlExecWithResult(
spark,
url,
s"SELECT $selectedColumns FROM $table",
password
)
val existingItemsSet = existingItems.collect().map(_.mkString(",")).toSet
val dfWithSelectedColumns = if (primaryKeys.isEmpty) df else df.select(primaryKeys.map(col): _*)
val dfFiltered = dfWithSelectedColumns
.filter(
row =>
!existingItemsSet.contains(
row.mkString(",")
)
)
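// retry once with the filtered rows; withRetry = false below prevents a second retry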
outputBulk(
spark,
url,
table,
dfFiltered,
path,
numPartitions,
password,
reindex,
bulkLoadBufferSize,
withRetry = false
)
}
}

} finally {
if (reindex)
indexReactivate(url, table, password)
@@ -235,7 +278,7 @@
numPartitions: Int = 8,
password: String = "",
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
): Unit = {
): Boolean = {

// load the csv files from hdfs in parallel
val fs = FileSystem.get(new Configuration())
@@ -251,22 +294,31 @@
.rdd
.partitionBy(new ExactPartitioner(numPartitions))

rdd.foreachPartition(x => {
val statusRdd = rdd.mapPartitions(x => {
val conn = connOpen(url, password)
x.foreach { s =>
{
val stream: InputStream = FileSystem
.get(new Configuration())
.open(new Path(s._2))
.getWrappedStream
val copyManager: CopyManager =
new CopyManager(conn.asInstanceOf[BaseConnection])
copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
}
val res = Try {
x.map { s =>
{
val stream: InputStream = FileSystem
.get(new Configuration())
.open(new Path(s._2))
.getWrappedStream
val copyManager: CopyManager =
new CopyManager(conn.asInstanceOf[BaseConnection])
copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
}
}.toList
}
conn.close()
x.toIterator
res match {
case Success(_) => Iterator(true) // Partition succeeded
case Failure(error) => {
logger.error("Partition output loading failed", error)
Iterator(false) // Partition failed
}
}
})
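// the load succeeded only if no partition reported a failure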
!statusRdd.collect().contains(false)
}

def outputBulkCsvLow(
Expand All @@ -280,7 +332,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
extensionPattern: String = ".*.csv",
password: String = "",
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
): Unit = {
): Boolean = {
val csvSqlCopy =
s"""COPY "$table" ($columns) FROM STDIN WITH CSV DELIMITER '$delimiter' NULL '' ESCAPE '"' QUOTE '"' """
outputBulkFileLow(
@@ -301,7 +353,7 @@
schema
}

def indexDeactivate(url: String, table: String, password: String = ""): Unit = {
def indexDeactivate(url: String, table: String, password: String): Unit = {
val schema = getSchema(url)
val query =
s"""
@@ -316,7 +368,7 @@
logger.debug(s"Deactivating indexes from $schema.$table")
}

def indexReactivate(url: String, table: String, password: String = ""): Unit = {
def indexReactivate(url: String, table: String, password: String): Unit = {

val schema = getSchema(url)
val query =
@@ -114,6 +114,7 @@ class PGCohortCreationTest extends AnyFunSuiteLike with DatasetComparer {
ArgumentMatchers.eq("list__entry_cohort360"),
df.capture(),
ArgumentMatchers.eq(Some(4)),
ArgumentMatchersSugar.*,
ArgumentMatchersSugar.*
)
assertSmallDatasetEquality(df.getValue.asInstanceOf[DataFrame], expectedDf)
68 changes: 68 additions & 0 deletions src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala
@@ -0,0 +1,68 @@
package fr.aphp.id.eds.requester.cohort.pg

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuiteLike
import org.scalatest.matchers.should.Matchers
import org.testcontainers.containers.PostgreSQLContainer

import java.nio.file.{Files, Path}

class PGToolTest extends AnyFunSuiteLike with Matchers with BeforeAndAfterAll {
val sparkSession: SparkSession = SparkSession
.builder()
.master("local[*]")
.getOrCreate()
private var tempDir: java.nio.file.Path = _
private val postgresContainer = new PostgreSQLContainer("postgres:15.3")

override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Files.createTempDirectory("test-temp-dir")
postgresContainer.withUsername("test")
postgresContainer.withPassword("test")
postgresContainer.start()
val pgPassFile = tempDir.resolve(".pgpass")
Files.write(pgPassFile, s"${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}:*:${postgresContainer.getUsername}:${postgresContainer.getPassword}".getBytes)
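// PGTool looks up the password in this .pgpass file, matching the host, port, database and user parsed from the JDBC URL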
}

override def afterAll(): Unit = {
super.afterAll()
FileUtils.deleteDirectory(tempDir.toFile)
postgresContainer.stop()
}

test("testOutputBulk") {
import sparkSession.implicits._
val pgUrl = s"jdbc:postgresql://${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}/${postgresContainer.getDatabaseName}?user=${postgresContainer.getUsername}&currentSchema=public"
val pgTool = PGTool(sparkSession, pgUrl, tempDir.toString, pgPassFile = new org.apache.hadoop.fs.Path(tempDir.resolve(".pgpass").toString))
val createTableQuery = """
CREATE TABLE test_table (
id INT PRIMARY KEY,
value TEXT,
id_2 INT
)
"""
pgTool.sqlExec(createTableQuery)

val insertDataQuery = """
INSERT INTO test_table (id, value, id_2) VALUES
(1, '1', 1),
(2, '2', 2)
"""
pgTool.sqlExec(insertDataQuery)
val baseContent = pgTool.sqlExecWithResult("select * from test_table")
baseContent.collect().map(_.getInt(0)) should contain theSameElementsAs Array(1, 2)

// generate a new dataframe containing 100 elements with 2 columns id and value that will be written to the database
val data = sparkSession.range(100).toDF("id").withColumn("value", 'id.cast("string")).withColumn("id_2", col("id"))
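// ids 1 and 2 already exist, so the initial COPY hits a primary-key conflict and the retry loads only the missing rows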
pgTool.outputBulk("test_table", data, primaryKeys = Seq("id", "id_2"))
val updatedContent = pgTool.sqlExecWithResult("select * from test_table")
updatedContent.collect().map(_.getInt(0)) should contain theSameElementsAs (0 until 100)
}

}
