diff --git a/pom.xml b/pom.xml
index e8c09ef..ca6e6be 100644
--- a/pom.xml
+++ b/pom.xml
@@ -31,6 +31,7 @@
2.6.5
4.0.4
3.3.1
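+ <!-- Testcontainers release providing the PostgreSQL module used by PGToolTest -->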
+ <postgrestest.version>1.20.4</postgrestest.version>
true
@@ -296,6 +297,12 @@
<version>${wiremock.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.testcontainers</groupId>
+ <artifactId>postgresql</artifactId>
+ <version>${postgrestest.version}</version>
+ <scope>test</scope>
+ </dependency>
@@ -446,7 +453,9 @@
- --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED
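+ <!-- additional opens for JDK internal packages that Spark and its serializers access during tests -->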
+ <argLine>
+ --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED
+ </argLine>
diff --git a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
index 146a0f2..afddc6a 100644
--- a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
@@ -50,8 +50,8 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
}
/**
- * This loads both a cohort and its definition into postgres and solr
- */
+ * This loads both a cohort and its definition into postgres and solr
+ */
override def updateCohort(cohortId: Long,
cohort: DataFrame,
sourcePopulation: SourcePopulation,
@@ -68,9 +68,9 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
.withColumnRenamed(ResultColumn.SUBJECT, "_itemreferenceid")
.withColumn("item__reference", concat(lit(s"${resourceType}/"), col("_itemreferenceid")))
.select(F.col("_itemreferenceid"),
- F.col("item__reference"),
- F.col("_provider"),
- F.col("_listid"))
+ F.col("item__reference"),
+ F.col("_provider"),
+ F.col("_listid"))
uploadCohortTableToPG(dataframe)
@@ -172,7 +172,10 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
).toSet == df.columns.toSet,
"cohort dataframe shall have _listid, _provider, _provider and item__reference"
)
- pg.outputBulk(cohort_item_table_rw, dfAddHash(df), Some(4))
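+ // pass the item primary keys so a failed bulk load can be retried without the rows already stored in postgres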
+ pg.outputBulk(cohort_item_table_rw,
+ dfAddHash(df),
+ Some(4),
+ primaryKeys = Seq("_listid", "_itemreferenceid", "_provider"))
}
/**
diff --git a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala
index 9f8478d..799695a 100755
--- a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala
@@ -14,6 +14,7 @@ import java.io._
import java.sql._
import java.util.Properties
import java.util.UUID.randomUUID
+import scala.util.{Failure, Success, Try}
sealed trait BulkLoadMode
@@ -35,8 +36,8 @@ class PGTool(
private var password: String = ""
- def setPassword(pwd: String = ""): PGTool = {
- password = PGTool.passwordFromConn(url, pwd)
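+ // the password is passed in directly; resolution from a .pgpass file happens in PGTool.apply via passwordFromConn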
+ def setPassword(pwd: String): PGTool = {
+ password = pwd
this
}
@@ -73,7 +74,8 @@ class PGTool(
table: String,
df: Dataset[Row],
numPartitions: Option[Int] = None,
- reindex: Boolean = false
+ reindex: Boolean = false,
+ primaryKeys: Seq[String] = Seq.empty
): PGTool = {
PGTool.outputBulk(
spark,
@@ -84,7 +86,8 @@ class PGTool(
numPartitions.getOrElse(8),
password,
reindex,
- bulkLoadBufferSize
+ bulkLoadBufferSize,
+ primaryKeys = primaryKeys
)
this
}
@@ -106,7 +109,8 @@ object PGTool extends java.io.Serializable with LazyLogging {
url: String,
tmpPath: String,
bulkLoadMode: BulkLoadMode = defaultBulkLoadStrategy,
- bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
+ bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
+ pgPassFile: Path = new Path(scala.sys.env("HOME"), ".pgpass")
): PGTool = {
new PGTool(
spark,
@@ -114,37 +118,35 @@ object PGTool extends java.io.Serializable with LazyLogging {
tmpPath + "/spark-postgres-" + randomUUID.toString,
bulkLoadMode,
bulkLoadBufferSize
- ).setPassword()
+ ).setPassword(passwordFromConn(url, pgPassFile))
}
- def connOpen(url: String, password: String = ""): Connection = {
+ def connOpen(url: String, password: String): Connection = {
val prop = new Properties()
prop.put("driver", "org.postgresql.Driver")
- prop.put("password", passwordFromConn(url, password))
+ prop.put("password", password)
DriverManager.getConnection(url, prop)
}
- def passwordFromConn(url: String, password: String): String = {
- if (password.nonEmpty) {
- return password
- }
+ def passwordFromConn(url: String, pgPassFile: Path): String = {
val pattern = "jdbc:postgresql://(.*):(\\d+)/(\\w+)[?]user=(\\w+).*".r
val pattern(host, port, database, username) = url
- dbPassword(host, port, database, username)
+ dbPassword(host, port, database, username, pgPassFile)
}
private def dbPassword(
hostname: String,
port: String,
database: String,
- username: String
+ username: String,
+ pgPassFile: Path
): String = {
- // Usage: val thatPassWord = dbPassword(hostname,port,database,username)
+ // Usage: val thatPassWord = dbPassword(hostname, port, database, username, pgPassFile)
// .pgpass file format, hostname:port:database:username:password
val fs = FileSystem.get(new java.net.URI("file:///"), new Configuration)
val reader = new BufferedReader(
- new InputStreamReader(fs.open(new Path(scala.sys.env("HOME"), ".pgpass")))
+ new InputStreamReader(fs.open(pgPassFile))
)
val content = Iterator
.continually(reader.readLine())
@@ -185,7 +187,9 @@ object PGTool extends java.io.Serializable with LazyLogging {
numPartitions: Int = 8,
password: String = "",
reindex: Boolean = false,
- bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
+ bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
+ withRetry: Boolean = true,
+ primaryKeys: Seq[String] = Seq.empty
): Unit = {
logger.debug("using CSV strategy")
try {
@@ -208,7 +212,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
.mode(org.apache.spark.sql.SaveMode.Overwrite)
.save(path)
- outputBulkCsvLow(
+ val success = outputBulkCsvLow(
spark,
url,
table,
@@ -220,6 +224,45 @@ object PGTool extends java.io.Serializable with LazyLogging {
password,
bulkLoadBufferSize
)
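+ // success is false when at least one partition failed to COPY its CSV file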
+ if (!success) {
+ if (!withRetry) {
+ throw new Exception("Bulk load failed")
+ } else {
+ logger.warn(
+ "Bulk load failed, retrying after filtering out rows already present in the target table"
+ )
+ // retry by removing from the original dataframe the rows that already exist in the target table
+ val selectedColumns = if (primaryKeys.isEmpty) "*" else primaryKeys.map(sanP).mkString(",")
+ val existingItems = sqlExecWithResult(
+ spark,
+ url,
+ s"SELECT $selectedColumns FROM $table",
+ password
+ )
+ val existingItemsSet = existingItems.collect().map(_.mkString(",")).toSet
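+ // rows are compared on their primary key values only (all columns when no keys are given), keeping every column of df for the reload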
+ val dfFiltered = df
+ .filter(
+ row =>
+ !existingItemsSet.contains(
+ (if (primaryKeys.isEmpty) row.toSeq
+ else primaryKeys.map(k => row.getAs[Any](k))).mkString(",")
+ )
+ )
+ outputBulk(
+ spark,
+ url,
+ table,
+ dfFiltered,
+ path,
+ numPartitions,
+ password,
+ reindex,
+ bulkLoadBufferSize,
+ withRetry = false
+ )
+ }
+ }
+
} finally {
if (reindex)
indexReactivate(url, table, password)
@@ -235,7 +278,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
numPartitions: Int = 8,
password: String = "",
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
- ): Unit = {
+ ): Boolean = {
// load the csv files from hdfs in parallel
val fs = FileSystem.get(new Configuration())
@@ -251,22 +294,31 @@ object PGTool extends java.io.Serializable with LazyLogging {
.rdd
.partitionBy(new ExactPartitioner(numPartitions))
- rdd.foreachPartition(x => {
+ val statusRdd = rdd.mapPartitions(x => {
val conn = connOpen(url, password)
- x.foreach { s =>
- {
- val stream: InputStream = FileSystem
- .get(new Configuration())
- .open(new Path(s._2))
- .getWrappedStream
- val copyManager: CopyManager =
- new CopyManager(conn.asInstanceOf[BaseConnection])
- copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
- }
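+ // run the COPY for each CSV file of this partition inside Try so a failure marks the partition as failed instead of aborting the whole job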
+ val res = Try {
+ x.map { s =>
+ {
+ val stream: InputStream = FileSystem
+ .get(new Configuration())
+ .open(new Path(s._2))
+ .getWrappedStream
+ val copyManager: CopyManager =
+ new CopyManager(conn.asInstanceOf[BaseConnection])
+ copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
+ }
+ }.toList
}
conn.close()
- x.toIterator
+ res match {
+ case Success(_) => Iterator(true) // Partition succeeded
+ case Failure(error) => {
+ logger.error("Partition output loading failed", error)
+ Iterator(false) // Partition failed
+ }
+ }
})
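+ // the bulk load succeeded only if every partition reported success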
+ !statusRdd.collect().contains(false)
}
def outputBulkCsvLow(
@@ -280,7 +332,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
extensionPattern: String = ".*.csv",
password: String = "",
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
- ): Unit = {
+ ): Boolean = {
val csvSqlCopy =
s"""COPY "$table" ($columns) FROM STDIN WITH CSV DELIMITER '$delimiter' NULL '' ESCAPE '"' QUOTE '"' """
outputBulkFileLow(
@@ -301,7 +353,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
schema
}
- def indexDeactivate(url: String, table: String, password: String = ""): Unit = {
+ def indexDeactivate(url: String, table: String, password: String): Unit = {
val schema = getSchema(url)
val query =
s"""
@@ -316,7 +368,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
logger.debug(s"Deactivating indexes from $schema.$table")
}
- def indexReactivate(url: String, table: String, password: String = ""): Unit = {
+ def indexReactivate(url: String, table: String, password: String): Unit = {
val schema = getSchema(url)
val query =
diff --git a/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
index 0414104..ece9eaf 100644
--- a/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
+++ b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
@@ -114,6 +114,7 @@ class PGCohortCreationTest extends AnyFunSuiteLike with DatasetComparer {
ArgumentMatchers.eq("list__entry_cohort360"),
df.capture(),
ArgumentMatchers.eq(Some(4)),
+ ArgumentMatchersSugar.*,
ArgumentMatchersSugar.*
)
assertSmallDatasetEquality(df.getValue.asInstanceOf[DataFrame], expectedDf)
diff --git a/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala
new file mode 100644
index 0000000..295f4cb
--- /dev/null
+++ b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala
@@ -0,0 +1,68 @@
+package fr.aphp.id.eds.requester.cohort.pg
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.matchers.should.Matchers
+import org.testcontainers.containers.PostgreSQLContainer
+import org.scalatest.funsuite.AnyFunSuiteLike
+
+import java.nio.file.{Files, Path}
+
+class PGToolTest extends AnyFunSuiteLike with Matchers with BeforeAndAfterAll {
+ val sparkSession: SparkSession = SparkSession
+ .builder()
+ .master("local[*]")
+ .getOrCreate()
+ private var tempDir: java.nio.file.Path = _
+ private val postgresContainer = new PostgreSQLContainer("postgres:15.3")
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ tempDir = Files.createTempDirectory("test-temp-dir")
+ postgresContainer.withUsername("test")
+ postgresContainer.withPassword("test")
+ postgresContainer.start()
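+ // write a .pgpass entry (hostname:port:database:username:password) for the container so PGTool can resolve the password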
+ val pgPassFile = tempDir.resolve(".pgpass")
+ Files.write(pgPassFile, s"${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}:*:${postgresContainer.getUsername}:${postgresContainer.getPassword}".getBytes)
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ FileUtils.deleteDirectory(tempDir.toFile)
+ postgresContainer.stop()
+ }
+
+ test("testOutputBulk") {
+ import sparkSession.implicits._
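+ // the user query parameter is mandatory: passwordFromConn parses it from the URL to look up the password in the .pgpass file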
+ val pgUrl = s"jdbc:postgresql://${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}/${postgresContainer.getDatabaseName}?user=${postgresContainer.getUsername}¤tSchema=public"
+ val pgTool = PGTool(sparkSession, pgUrl, tempDir.toString, pgPassFile = new org.apache.hadoop.fs.Path(tempDir.resolve(".pgpass").toString))
+ val createTableQuery = """
+ CREATE TABLE test_table (
+ id INT PRIMARY KEY,
+ value TEXT,
+ id_2 INT
+ )
+ """
+ pgTool.sqlExec(createTableQuery)
+
+ val insertDataQuery = """
+ INSERT INTO test_table (id, value, id_2) VALUES
+ (1, '1', 1),
+ (2, '2', 2)
+ """
+ pgTool.sqlExec(insertDataQuery)
+ val baseContent = pgTool.sqlExecWithResult("select * from test_table")
+ baseContent.collect().map(_.getInt(0)) should contain theSameElementsAs Array(1, 2)
+
+ // generate a new dataframe containing 100 elements with 2 columns id and value that will be written to the database
+ val data = sparkSession.range(100).toDF("id").withColumn("value", 'id.cast("string")).withColumn("id_2", col("id"))
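+ // ids 1 and 2 collide with the rows inserted above, so the initial bulk load fails and the retry reloads only the rows whose (id, id_2) pair is not already in the table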
+ pgTool.outputBulk("test_table", data, primaryKeys = Seq("id", "id_2"))
+ val updatedContent = pgTool.sqlExecWithResult("select * from test_table")
+ updatedContent.collect().map(_.getInt(0)) should contain theSameElementsAs (0 until 100)
+ }
+
+}