Skip to content

Commit

Permalink
feat: add sampled cohort creation (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
pl-buiquang authored Nov 25, 2024
1 parent 47e8a26 commit 9ce2a5f
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 10 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ The job query format is as follows :
"cohortDefinitionSyntax": "<cohort definition syntax>",
"mode": "<mode>",
"modeOptions": { // optional mode options
// list of criteria ids separated by commas or "all", this will activate a detailed count of the patients per criteria
// "ratio", this will activate a detailed count of final matched patients per criteria
"details": "<details>"
// optional list of criteria ids separated by commas or "all", this will activate a detailed count of the patients per criteria
// or "ratio", this will activate a detailed count of final matched patients per criteria
"details": "<details>",
// optional sampling ratio between 0.0 and 1.0 that limits the number of patients in the created cohort (it can also be used to sample an existing cohort)
"sampling": "<sampling>"
},
"callbackUrl": "<callback url>" // optional callback url to retrieve the result
}
Expand Down
11 changes: 11 additions & 0 deletions src/main/scala/fr/aphp/id/eds/requester/CreateQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import fr.aphp.id.eds.requester.tools.{JobUtils, JobUtilsService, StageDetails}
import org.apache.log4j.Logger
import org.apache.spark.sql.{SparkSession, functions => F}

/** Keys recognised in `SparkJobParameter.modeOptions` by cohort-creation (`CreateQuery`) jobs.
  *
  * Note: this was previously declared as `extends Enumeration`, but no `Value` member was ever
  * defined — the options are plain string keys looked up in a `Map[String, String]`. A simple
  * constants object expresses that directly and avoids the discouraged `scala.Enumeration`.
  */
object CreateOptions {
  /** Alias retained for source compatibility with the former Enumeration-style definition. */
  type CreateOptions = String

  /** Optional sampling ratio (a string parseable as a Double in (0.0, 1.0]) used to limit the
    * number of patients kept in the created cohort.
    */
  val sampling: String = "sampling"
}

case class CreateQuery(queryBuilder: QueryBuilder = new DefaultQueryBuilder(),
jobUtilsService: JobUtilsService = JobUtils)
extends JobBase {
Expand Down Expand Up @@ -68,6 +73,12 @@ case class CreateQuery(queryBuilder: QueryBuilder = new DefaultQueryBuilder(),
.map(c => F.col(c)): _*)
.dropDuplicates()

if (data.modeOptions.contains(CreateOptions.sampling)) {
val sampling = data.modeOptions(CreateOptions.sampling).toDouble
// https://stackoverflow.com/questions/37416825/dataframe-sample-in-apache-spark-scala#comment62349780_37418684
// to be sure to have the right number of rows
cohort = cohort.sample(sampling+0.1).limit((sampling * cohort.count()).round.toInt)
}
cohort.cache()
count = cohort.count()
val cohortSizeBiggerThanLimit = count > LIMIT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ case class SparkJobParameter(
solrRows: String = "10000",
commitWithin: String = "10000",
mode: String = JobType.count,
// see fr.aphp.id.eds.requester.{CountQuery, CreateQuery} for the usage of modeOptions
modeOptions: Map[String, String] = Map.empty,
cohortUuid: Option[String] = Option.empty,
existingCohortId: Option[Long] = Option.empty,
Expand Down
58 changes: 51 additions & 7 deletions src/test/scala/fr/aphp/id/eds/requester/CreateQueryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package fr.aphp.id.eds.requester
import fr.aphp.id.eds.requester.cohort.CohortCreation
import fr.aphp.id.eds.requester.jobs.{JobEnv, JobsConfig, SparkJobParameter}
import fr.aphp.id.eds.requester.query.engine.QueryBuilder
import fr.aphp.id.eds.requester.query.model.SourcePopulation
import fr.aphp.id.eds.requester.query.resolver.{ResourceResolver, ResourceResolvers}
import fr.aphp.id.eds.requester.tools.JobUtilsService
import org.apache.spark.sql.SparkSession
import org.mockito.ArgumentMatchersSugar
import org.mockito.MockitoSugar.{mock, when}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.mockito.{ArgumentCaptor, ArgumentMatchersSugar}
import org.mockito.MockitoSugar.{mock, spy, verify, when}
import org.scalatest.funsuite.AnyFunSuiteLike

import java.nio.file.Paths
Expand Down Expand Up @@ -60,7 +61,7 @@ class CreateQueryTest extends AnyFunSuiteLike {
.getOrCreate()

val queryBuilderMock = mock[QueryBuilder]
val omopTools = mock[CohortCreation]
val omopTools = spy(mock[CohortCreation])
val resourceResolver = ResourceResolver.get(ResourceResolvers.solr)
class JobUtilsMock extends JobUtilsService {
override def getRandomIdNotInTabooList(allTabooId: List[Short], negative: Boolean): Short = 99
Expand Down Expand Up @@ -128,7 +129,7 @@ class CreateQueryTest extends AnyFunSuiteLike {
}
assert(error.getMessage == "Non-patient resource filter request should be a basic resource")

val expected = getClass.getResource(s"/testCases/simple/expected.csv")
val expected = getClass.getResource(s"/testCases/stageRatioDetails/expected.csv")
val expectedResult = sparkSession.read
.format("csv")
.option("delimiter", ";")
Expand All @@ -146,6 +147,14 @@ class CreateQueryTest extends AnyFunSuiteLike {
ArgumentMatchersSugar.*
)
).thenReturn(expectedResult)
when(omopTools.createCohort(
ArgumentMatchersSugar.eqTo("testCohortSimple"),
ArgumentMatchersSugar.any[Option[String]],
ArgumentMatchersSugar.any[String],
ArgumentMatchersSugar.any[String],
ArgumentMatchersSugar.any[String],
ArgumentMatchersSugar.any[Long]
)).thenReturn(0)
val request =
"""
{"cohortUuid":"ecd89963-ac90-470d-a397-c846882615a6","sourcePopulation":{"caresiteCohortList":[31558]},"_type":"request","request":{"_type":"andGroup","_id":0,"isInclusive":true,"criteria":[{"_type":"basicResource","_id":1,"isInclusive":true,"resourceType":"patientAphp","filterSolr":"fq=gender:f&fq=deceased:false&fq=active:true","filterFhir":"active=true&gender=f&deceased=false&age-day=ge0&age-day=le130"}],"temporalConstraints":[]}}"
Expand All @@ -154,15 +163,50 @@ class CreateQueryTest extends AnyFunSuiteLike {
sparkSession,
JobEnv("someid", AppConfig.get),
SparkJobParameter(
"testCohort",
"testCohortSimple",
None,
request,
"someOwnerId"
)
)
assert(res.status == "FINISHED")
assert(res.data("group.count") == "2")
assert(res.data("group.count") == "6")
assert(res.data("group.id") == "0")

when(omopTools.createCohort(
ArgumentMatchersSugar.eqTo("testCohortSampling"),
ArgumentMatchersSugar.any[Option[String]],
ArgumentMatchersSugar.any[String],
ArgumentMatchersSugar.any[String],
ArgumentMatchersSugar.any[String],
ArgumentMatchersSugar.any[Long]
)).thenReturn(1L)
val sampled = createJob.runJob(
sparkSession,
JobEnv("someid", AppConfig.get),
SparkJobParameter(
"testCohortSampling",
None,
request,
"someOwnerId",
modeOptions = Map("sampling" -> "0.33")
)
)
val omopToolsCaptor = ArgumentCaptor.forClass(classOf[org.apache.spark.sql.DataFrame])
verify(omopTools).updateCohort(
ArgumentMatchersSugar.eqTo(1L),
omopToolsCaptor.capture(),
ArgumentMatchersSugar.any[SourcePopulation],
ArgumentMatchersSugar.any[Long],
ArgumentMatchersSugar.any[Boolean],
ArgumentMatchersSugar.any[String]
)
val capturedDataFrame: DataFrame = omopToolsCaptor.getValue
assert(capturedDataFrame.columns.contains("subject_id"))
assert(capturedDataFrame.count() >= 1 && capturedDataFrame.count() <= 2)
assert(sampled.status == "FINISHED")
assert(sampled.data("group.count").toInt >= 1 && sampled.data("group.count").toInt <= 2)
assert(sampled.data("group.id") == "1")
}

}

0 comments on commit 9ce2a5f

Please sign in to comment.