diff --git a/pom.xml b/pom.xml
index e8c09ef..ca6e6be 100644
--- a/pom.xml
+++ b/pom.xml
@@ -31,6 +31,7 @@
         2.6.5
         4.0.4
         3.3.1
+        1.20.4
         true
@@ -296,6 +297,12 @@
             ${wiremock.version}
             test
         
+        
+            org.testcontainers
+            postgresql
+            ${postgrestest.version}
+            test
+        
     
@@ -446,7 +453,9 @@
-            --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED
+            
+            --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED
+            
diff --git a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
index 146a0f2..afddc6a 100644
--- a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
@@ -50,8 +50,8 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
   }
 
   /**
-   * This loads both a cohort and its definition into postgres and solr
-   */
+    * This loads both a cohort and its definition into postgres and solr
+    */
   override def updateCohort(cohortId: Long,
                             cohort: DataFrame,
                             sourcePopulation: SourcePopulation,
@@ -68,9 +68,9 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
       .withColumnRenamed(ResultColumn.SUBJECT, "_itemreferenceid")
       .withColumn("item__reference", concat(lit(s"${resourceType}/"), col("_itemreferenceid")))
       .select(F.col("_itemreferenceid"),
-        F.col("item__reference"),
-        F.col("_provider"),
-        F.col("_listid"))
+              F.col("item__reference"),
+              F.col("_provider"),
+              F.col("_listid"))
 
     uploadCohortTableToPG(dataframe)
 
@@ -172,7 +172,10 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
       ).toSet == df.columns.toSet,
       "cohort dataframe shall have _listid, _provider, _provider and item__reference"
     )
-    pg.outputBulk(cohort_item_table_rw, dfAddHash(df), Some(4))
+    pg.outputBulk(cohort_item_table_rw,
+                  dfAddHash(df),
+                  Some(4),
+                  primaryKeys = Seq("_listid", "_itemreferenceid", "_provider"))
   }
 
   /**
diff --git a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala
index 9f8478d..799695a 100755
--- a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala
@@ -14,6 +14,7 @@ import java.io._
 import java.sql._
 import java.util.Properties
 import java.util.UUID.randomUUID
+import scala.util.{Failure, Success, Try}
 
 sealed trait BulkLoadMode
 
@@ -35,8 +36,8 @@ class PGTool(
 
   private var password: String = ""
 
-  def setPassword(pwd: String = ""): PGTool = {
-    password = PGTool.passwordFromConn(url, pwd)
+  def setPassword(pwd: String): PGTool = {
+    password = pwd
     this
   }
 
@@ -73,7 +74,8 @@
       table: String,
       df: Dataset[Row],
       numPartitions: Option[Int] = None,
-      reindex: Boolean = false
+      reindex: Boolean = false,
+      primaryKeys: Seq[String] = Seq.empty
   ): PGTool = {
     PGTool.outputBulk(
       spark,
@@ -84,7 +86,8 @@
       numPartitions.getOrElse(8),
       password,
       reindex,
-      bulkLoadBufferSize
+      bulkLoadBufferSize,
+      primaryKeys = primaryKeys
     )
     this
  }
 
@@ -106,7 +109,8 @@ object PGTool extends java.io.Serializable with LazyLogging {
       url: String,
       tmpPath: String,
       bulkLoadMode: BulkLoadMode = defaultBulkLoadStrategy,
-      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
+      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
+      pgPassFile: Path = new Path(scala.sys.env("HOME"), ".pgpass")
   ): PGTool = {
     new PGTool(
       spark,
@@ -114,37 +118,35 @@
       tmpPath + "/spark-postgres-" + randomUUID.toString,
       bulkLoadMode,
       bulkLoadBufferSize
-    ).setPassword()
+    ).setPassword(passwordFromConn(url, pgPassFile))
   }
 
-  def connOpen(url: String, password: String = ""): Connection = {
+  def connOpen(url: String, password: String): Connection = {
     val prop = new Properties()
     prop.put("driver", "org.postgresql.Driver")
-    prop.put("password", passwordFromConn(url, password))
+    prop.put("password", password)
     DriverManager.getConnection(url, prop)
   }
 
-  def passwordFromConn(url: String, password: String): String = {
-    if (password.nonEmpty) {
-      return password
-    }
+  def passwordFromConn(url: String, pgPassFile: Path): String = {
     val pattern = "jdbc:postgresql://(.*):(\\d+)/(\\w+)[?]user=(\\w+).*".r
     val pattern(host, port, database, username) = url
-    dbPassword(host, port, database, username)
+    dbPassword(host, port, database, username, pgPassFile)
   }
 
   private def dbPassword(
       hostname: String,
       port: String,
       database: String,
-      username: String
+      username: String,
+      pgPassFile: Path
   ): String = {
     // Usage: val thatPassWord = dbPassword(hostname,port,database,username)
     // .pgpass file format, hostname:port:database:username:password
     val fs = FileSystem.get(new java.net.URI("file:///"), new Configuration)
     val reader = new BufferedReader(
-      new InputStreamReader(fs.open(new Path(scala.sys.env("HOME"), ".pgpass")))
+      new InputStreamReader(fs.open(pgPassFile))
     )
     val content = Iterator
       .continually(reader.readLine())
@@ -185,7 +187,9 @@
       numPartitions: Int = 8,
       password: String = "",
       reindex: Boolean = false,
-      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
+      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
+      withRetry: Boolean = true,
+      primaryKeys: Seq[String] = Seq.empty
   ): Unit = {
     logger.debug("using CSV strategy")
     try {
@@ -208,7 +212,7 @@
         .mode(org.apache.spark.sql.SaveMode.Overwrite)
         .save(path)
 
-      outputBulkCsvLow(
+      val success = outputBulkCsvLow(
         spark,
         url,
         table,
@@ -220,6 +224,45 @@
         password,
         bulkLoadBufferSize
       )
+      if (!success) {
+        if (!withRetry) {
+          throw new Exception("Bulk load failed")
+        } else {
+          logger.warn(
+            "Bulk load failed, retrying with filtering existing items"
+          )
+          // try again with filtering the original dataframe with existing items
+          val selectedColumns = if (primaryKeys.isEmpty) "*" else primaryKeys.map(sanP).mkString(",")
+          val existingItems = sqlExecWithResult(
+            spark,
+            url,
+            s"SELECT $selectedColumns FROM $table",
+            password
+          )
+          val existingItemsSet = existingItems.collect().map(_.mkString(",")).toSet
+          val dfWithSelectedColumns = if (primaryKeys.isEmpty) df else df.select(primaryKeys.map(col): _*)
+          val dfFiltered = dfWithSelectedColumns
+            .filter(
+              row =>
+                !existingItemsSet.contains(
+                  row.mkString(",")
+                )
+            )
+          outputBulk(
+            spark,
+            url,
+            table,
+            dfFiltered,
+            path,
+            numPartitions,
+            password,
+            reindex,
+            bulkLoadBufferSize,
+            withRetry = false
+          )
+        }
+      }
+
     } finally {
       if (reindex)
         indexReactivate(url, table, password)
@@ -235,7 +278,7 @@
       numPartitions: Int = 8,
       password: String = "",
       bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
-  ): Unit = {
+  ): Boolean = {
 
     // load the csv files from hdfs in parallel
     val fs = FileSystem.get(new Configuration())
@@ -251,22 +294,31 @@
       .rdd
       .partitionBy(new ExactPartitioner(numPartitions))
 
-    rdd.foreachPartition(x => {
+    val statusRdd = rdd.mapPartitions(x => {
       val conn = connOpen(url, password)
-      x.foreach { s =>
-        {
-          val stream: InputStream = FileSystem
-            .get(new Configuration())
-            .open(new Path(s._2))
-            .getWrappedStream
-          val copyManager: CopyManager =
-            new CopyManager(conn.asInstanceOf[BaseConnection])
-          copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
-        }
+      val res = Try {
+        x.map { s =>
+          {
+            val stream: InputStream = FileSystem
+              .get(new Configuration())
+              .open(new Path(s._2))
+              .getWrappedStream
+            val copyManager: CopyManager =
+              new CopyManager(conn.asInstanceOf[BaseConnection])
+            copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
+          }
+        }.toList
       }
       conn.close()
-      x.toIterator
+      res match {
+        case Success(_) => Iterator(true) // Partition succeeded
+        case Failure(error) => {
+          logger.error("Partition output loading failed", error)
+          Iterator(false) // Partition failed
+        }
+      }
     })
+    !statusRdd.collect().contains(false)
   }
 
   def outputBulkCsvLow(
@@ -280,7 +332,7 @@
       extensionPattern: String = ".*.csv",
       password: String = "",
       bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
-  ): Unit = {
+  ): Boolean = {
     val csvSqlCopy =
       s"""COPY "$table" ($columns) FROM STDIN WITH CSV DELIMITER '$delimiter' NULL '' ESCAPE '"' QUOTE '"' """
     outputBulkFileLow(
@@ -301,7 +353,7 @@
     schema
   }
 
-  def indexDeactivate(url: String, table: String, password: String = ""): Unit = {
+  def indexDeactivate(url: String, table: String, password: String): Unit = {
     val schema = getSchema(url)
     val query =
       s"""
@@ -316,7 +368,7 @@
     logger.debug(s"Deactivating indexes from $schema.$table")
   }
 
-  def indexReactivate(url: String, table: String, password: String = ""): Unit = {
+  def indexReactivate(url: String, table: String, password: String): Unit = {
     val schema = getSchema(url)
     val query =
diff --git a/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
index 0414104..ece9eaf 100644
--- a/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
+++ b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
@@ -114,6 +114,7 @@ class PGCohortCreationTest extends AnyFunSuiteLike with DatasetComparer {
       ArgumentMatchers.eq("list__entry_cohort360"),
       df.capture(),
       ArgumentMatchers.eq(Some(4)),
+      ArgumentMatchersSugar.*,
       ArgumentMatchersSugar.*
     )
     assertSmallDatasetEquality(df.getValue.asInstanceOf[DataFrame], expectedDf)
diff --git a/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala
new file mode 100644
index 0000000..295f4cb
--- /dev/null
+++ b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala
@@ -0,0 +1,68 @@
+package fr.aphp.id.eds.requester.cohort.pg
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+import org.testcontainers.containers.PostgreSQLContainer
+import org.scalatest.funsuite.AnyFunSuiteLike
+
+import java.nio.file.{Files, Path}
+
+class PGToolTest extends AnyFunSuiteLike with Matchers with BeforeAndAfterAll {
+  val sparkSession: SparkSession = SparkSession
+    .builder()
+    .master("local[*]")
+    .getOrCreate()
+  private var tempDir: java.nio.file.Path = _
+  private val postgresContainer = new PostgreSQLContainer("postgres:15.3")
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    tempDir = Files.createTempDirectory("test-temp-dir")
+    postgresContainer.withUsername("test")
+    postgresContainer.withPassword("test")
+    postgresContainer.start()
+    val pgPassFile = tempDir.resolve(".pgpass")
+    Files.write(pgPassFile, s"${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}:*:${postgresContainer.getUsername}:${postgresContainer.getPassword}".getBytes)
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    FileUtils.deleteDirectory(tempDir.toFile)
+    postgresContainer.stop()
+  }
+
+  test("testOutputBulk") {
+    import sparkSession.implicits._
+    val pgUrl = s"jdbc:postgresql://${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}/${postgresContainer.getDatabaseName}?user=${postgresContainer.getUsername}&currentSchema=public"
+    val pgTool = PGTool(sparkSession, pgUrl, tempDir.toString, pgPassFile = new org.apache.hadoop.fs.Path(tempDir.resolve(".pgpass").toString))
+    val createTableQuery = """
+      CREATE TABLE test_table (
+        id INT PRIMARY KEY,
+        value TEXT,
+        id_2 INT
+      )
+    """
+    pgTool.sqlExec(createTableQuery)
+
+    val insertDataQuery = """
+      INSERT INTO test_table (id, value, id_2) VALUES
+      (1, '1', 1),
+      (2, '2', 2)
+    """
+    pgTool.sqlExec(insertDataQuery)
+    val baseContent = pgTool.sqlExecWithResult("select * from test_table")
+    baseContent.collect().map(_.getInt(0)) should contain theSameElementsAs Array(1, 2)
+
+    // generate a dataframe of 100 rows with columns id, value and id_2 that will be bulk-loaded into the database
+    val data = sparkSession.range(100).toDF("id").withColumn("value", 'id.cast("string")).withColumn("id_2", col("id"))
+    pgTool.outputBulk("test_table", data, primaryKeys = Seq("id", "id_2"))
+    val updatedContent = pgTool.sqlExecWithResult("select * from test_table")
+    updatedContent.collect().map(_.getInt(0)) should contain theSameElementsAs (0 until 100)
+  }
+
+}
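
Reviewer note: below is a minimal usage sketch tying the two API additions of this change together, the explicit pgPassFile argument on PGTool.apply and the primaryKeys-driven retry in outputBulk. The connection URL, scratch path, table and column names are hypothetical placeholders; only the method and parameter names come from this diff.

import fr.aphp.id.eds.requester.cohort.pg.PGTool
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession

// Hypothetical caller; not part of the repository.
object PGToolUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // pgPassFile (new in this diff) points PGTool at an explicit .pgpass-style file
    // instead of the $HOME/.pgpass default.
    val pg = PGTool(
      spark,
      url = "jdbc:postgresql://db.example.org:5432/cohort?user=cohort_rw&currentSchema=public", // placeholder URL
      tmpPath = "/tmp/pg-bulk-staging",                                                         // placeholder scratch dir
      pgPassFile = new Path("/etc/secrets/.pgpass")                                             // placeholder credentials file
    )

    // The frame must match the target table's columns; this toy frame assumes a
    // hypothetical table with columns (id, value).
    val df = spark.range(10).toDF("id").withColumn("value", $"id".cast("string"))

    // With primaryKeys set, a COPY that fails (for example on duplicate keys) is retried once
    // after filtering out rows whose key tuples already exist in the target table.
    pg.outputBulk("my_table", df, numPartitions = Some(4), primaryKeys = Seq("id"))
  }
}

Note that the retry path in this diff collects the existing key tuples to the driver before filtering, so it assumes the target table's key set fits in driver memory.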