
Refactor the code to allow different dispatcher (#57)
Shwetank87 authored Jun 19, 2020
1 parent fe8edb0 commit 3d2fb6b
Showing 9 changed files with 306 additions and 159 deletions.
2 changes: 1 addition & 1 deletion build.sbt
@@ -24,7 +24,7 @@ val cloudwatchMetrics = "io.github.azagniotov" % "dropwizard-metrics-clo

lazy val commonSettings = Seq(
scalacOptions ++= Seq("-deprecation", "-feature", "-Xlint", "-Xfatal-warnings"),
scalaVersion := "2.12.8",
scalaVersion := "2.12.10",
libraryDependencies += scalaTestArtifact,
organization := "com.krux",
test in assembly := {}, // skip test during assembly
10 changes: 10 additions & 0 deletions conf-template/starport.conf
@@ -42,6 +42,16 @@ krux.starport {
EXAMPLE_ENVIRONMENT_VAR_2 = "some-value"
}

dispatcher {
# valid values are default or sqs
type = "default"
# applicable if type is set to sqs
sqs {
publishQueueUrl = ""
retrieveQueueUrl = ""
}
}

# slack webhook to be used (optional)
slack_webhook_url = "???"

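For illustration only, the new dispatcher block could be read with the Typesafe Config API roughly as sketched below. The SqsDispatcherSettings case class, the helper object, and the queue URLs are assumptions made for this example; the commit itself only exposes dispatcher.type (see the StarportSettings change further down).

import com.typesafe.config.{Config, ConfigFactory}

// Hypothetical holder for the sqs sub-config; not part of this commit.
case class SqsDispatcherSettings(publishQueueUrl: String, retrieveQueueUrl: String)

object DispatcherConfigExample {

  // Reads krux.starport.dispatcher, falling back to the "default" dispatcher
  // when the block is absent (mirroring StarportSettings.dispatcherType).
  def sqsSettings(config: Config): Option[SqsDispatcherSettings] = {
    val typePath = "krux.starport.dispatcher.type"
    val dispatcherType =
      if (config.hasPath(typePath)) config.getString(typePath) else "default"

    if (dispatcherType == "sqs") {
      val sqs = config.getConfig("krux.starport.dispatcher.sqs")
      Some(SqsDispatcherSettings(
        sqs.getString("publishQueueUrl"),
        sqs.getString("retrieveQueueUrl")
      ))
    } else None
  }

  def main(args: Array[String]): Unit = {
    // Placeholder queue URLs, not real endpoints.
    val conf = ConfigFactory.parseString(
      """krux.starport.dispatcher {
        |  type = "sqs"
        |  sqs {
        |    publishQueueUrl  = "https://sqs.us-east-1.amazonaws.com/123456789012/starport-publish"
        |    retrieveQueueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/starport-retrieve"
        |  }
        |}""".stripMargin)

    println(sqsSettings(conf))
  }
}
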
@@ -4,15 +4,15 @@ import java.util.concurrent.{ForkJoinPool, TimeUnit}

import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.ExecutionContext.Implicits.global
import scala.sys.process._
import com.codahale.metrics.{Counter, MetricRegistry}
import com.github.nscala_time.time.Imports._
import slick.jdbc.PostgresProfile.api._
import com.krux.hyperion.client.{AwsClient, AwsClientForName}
import com.krux.hyperion.expression.{Duration => HDuration}
import com.krux.starport.cli.{SchedulerOptionParser, SchedulerOptions}
import com.krux.starport.db.record.{Pipeline, ScheduledPipeline, SchedulerMetric}
import com.krux.starport.db.table.{Pipelines, ScheduleFailureCounters, ScheduledPipelines, SchedulerMetrics}
import com.krux.starport.dispatcher.TaskDispatcher
import com.krux.starport.dispatcher.impl.TaskDispatcherImpl
import com.krux.starport.metric.{ConstantValueGauge, MetricSettings, SimpleTimerGauge}
import com.krux.starport.util.{ErrorHandler, S3FileHandler}

@@ -58,137 +58,47 @@ object StartScheduledPipelines extends StarportActivity {
result
}

/**
* @return status, the output, and the deployed pipeline name
*/
def deployPipeline(
pipelineRecord: Pipeline,
currentTime: DateTime,
currentEndTime: DateTime,
localJar: String
): (Int, String, String) = {

// TODO probably better to just use a different logger
val logPrefix = s"|PipelineId: ${pipelineRecord.id}|"

logger.info(s"$logPrefix Deploying pipeline: ${pipelineRecord.name}")

val start = pipelineRecord.nextRunTime.get
val until = pipelineRecord.end
.map(DateTimeOrdering.min(currentEndTime, _))
.getOrElse(currentEndTime)
val pipelinePeriod = pipelineRecord.period

// Note that AWS Data Pipeline has a weird requirement for endTime (it is documented as having to
// be greater than startTime, but it actually has to be greater than startTime + period). This is
// very confusing, so we instead express the schedule as the number of times the pipeline should
// run, to avoid that confusion.
val calculatedTimes =
if (pipelineRecord.backfill) timesTillEnd(start, until, HDuration(pipelinePeriod))
else 1

logger.info(s"$logPrefix calculatedTimes: $calculatedTimes")
if (calculatedTimes < 1) {
// the calculatedTimes should never be < 1
logger.error(s"calculatedTimes < 1")
private def updateScheduledPipelines(scheduledPipelines: Seq[ScheduledPipeline]) = {
scheduledPipelines.isEmpty match {
case true => ()
case false =>
val insertAction = DBIO.seq(ScheduledPipelines() ++= scheduledPipelines)
db.run(insertAction).waitForResult
}
val times = Math.max(1, calculatedTimes)

val actualStart = DateTime.now.withZone(DateTimeZone.UTC).toString(DateTimeFormat)

val pipelineClass = pipelineRecord.`class`
val pipelineName = s"${conf.pipelinePrefix}${actualStart}_${pipelineRecord.id.getOrElse(0)}_${pipelineClass}"

// create the pipeline through the CLI but do not activate it
val command = Seq(
"java",
"-cp",
localJar,
pipelineClass,
"create",
"--no-check",
"--start", start.toString(DateTimeFormat),
"--times", times.toString,
"--every", pipelinePeriod,
"--name", pipelineName
) ++ conf.region.toSeq.flatMap(r => Seq("--region", r.getName))

val process = Process(
command,
None,
extraEnvs: _*
)
}

logger.info(s"$logPrefix Executing `${command.mkString(" ")}`")
private def updateNextRunTime(pipelineRecord: Pipeline, options: SchedulerOptions) = {
// update the next runtime in the database
val newNextRunTime = nextRunTime(pipelineRecord.nextRunTime.get, HDuration(pipelineRecord.period), options.scheduledEnd)
val updateQuery = Pipelines().filter(_.id === pipelineRecord.id).map(_.nextRunTime)
logger.debug(s"Update with query ${updateQuery.updateStatement}")
val updateAction = updateQuery.update(Some(newNextRunTime))
db.run(updateAction).waitForResult
}

val outputBuilder = new StringBuilder
val status = process ! ProcessLogger(line => outputBuilder.append(line + "\n"))

(status, outputBuilder.toString, pipelineName)
private def processPipeline(dispatcher: TaskDispatcher, pipeline: Pipeline, options: SchedulerOptions, jar: String, dispatchedPipelines: Counter, failedPipelines: Counter): Unit = {
val timerInst = scheduleTimer.time()
logger.info(s"Dispatching pipleine ${pipeline.name}")

}
dispatcher.dispatch(pipeline, options, jar, conf) match {
case Left(ex) =>
ErrorHandler.pipelineScheduleFailed(pipeline, ex.getMessage())
logger.warn(
s"failed to deploy pipeline ${pipeline.name} in ${TimeUnit.SECONDS.convert(timerInst.stop(), TimeUnit.NANOSECONDS)}"
)
failedPipelines.inc()
case Right(r) =>
logger.info(
s"dispatched pipeline ${pipeline.name} in ${TimeUnit.SECONDS.convert(timerInst.stop(), TimeUnit.NANOSECONDS)}"
)
dispatchedPipelines.inc()
// activation successful - delete the failure counter
db.run(ScheduleFailureCounters().filter(_.pipelineId === pipeline.id.get).delete).waitForResult
}

def activatePipeline(
pipelineRecord: Pipeline,
pipelineName: String,
scheduledStart: DateTime,
actualStart: DateTime,
scheduledEnd: DateTime
) = {

// TODO probably better to just use a different logger
val logPrefix = s"|PipelineId: ${pipelineRecord.id}|"

logger.info(s"$logPrefix Activating pipeline: $pipelineName...")

val awsClientForName = AwsClientForName(AwsClient.getClient(), pipelineName, conf.maxRetry)
val pipelineIdNameMap = awsClientForName.pipelineIdNames

awsClientForName
.forId() match {
case Some(client) =>
val activationStatus = if (client.activatePipelines().nonEmpty) {
"success"
} else {
logger.error(s"$logPrefix Failed to activate pipeline ${client.pipelineIds}")
"fail"
}

logger.info(s"$logPrefix Register pipelines (${client.pipelineIds}) in database.")

val scheduledPipelineRecords = client.pipelineIds.map(awsId =>
ScheduledPipeline(
awsId,
pipelineRecord.id.get,
pipelineIdNameMap(awsId),
scheduledStart,
actualStart,
DateTime.now,
activationStatus,
true
)
)

val insertAction = DBIO.seq(ScheduledPipelines() ++= scheduledPipelineRecords)
db.run(insertAction).waitForResult

logger.info(s"$logPrefix updating the next run time")

// update the next runtime in the database
val newNextRunTime = nextRunTime(pipelineRecord.nextRunTime.get, HDuration(pipelineRecord.period), scheduledEnd)
val updateQuery = Pipelines().filter(_.id === pipelineRecord.id).map(_.nextRunTime)
logger.debug(s"$logPrefix Update with query ${updateQuery.updateStatement}")
val updateAction = updateQuery.update(Some(newNextRunTime))
db.run(updateAction).waitForResult

// activate successful, reset the failure counter, by deleting it
db.run(ScheduleFailureCounters().filter(_.pipelineId === pipelineRecord.id.get).delete).waitForResult

logger.info(s"$logPrefix Successfully scheduled pipeline $pipelineName")
case None =>
val errorMessage = s"pipeline with name $pipelineName not found"
ErrorHandler.pipelineScheduleFailed(pipelineRecord, errorMessage)
}
// update the next run time for this pipeline
updateNextRunTime(pipeline, options)
}

def run(options: SchedulerOptions): Unit = {
@@ -198,6 +108,24 @@ object StartScheduledPipelines extends StarportActivity {
val actualStart = options.actualStart
db.run(DBIO.seq(SchedulerMetrics() += SchedulerMetric(actualStart))).waitForResult

// in the case of the default taskDispatcher, the dispatchedPipelines and successfulPipelines metrics should be exactly the same.
val dispatchedPipelines = metrics.register("counter.successful-pipeline-dispatch-count", new Counter())
val successfulPipelines = metrics.register("counter.successful-pipeline-deployment-count", new Counter())
val failedPipelines = metrics.register("counter.failed-pipeline-deployment-count", new Counter())

val taskDispatcher: TaskDispatcher = conf.dispatcherType match {
case "default" => new TaskDispatcherImpl()
case x =>
// pipelines scheduled in a previous run should be fetched here if the dispatcher is remote
throw new NotImplementedError(s"there is no task dispatcher implementation for $x")
}

// fetch the pipelines which may have been scheduled in previous runs but are not present in the database yet
// this operation is a no-op in the case of a local dispatcher like the TaskDispatcherImpl
val previouslyScheduledPipelines = taskDispatcher.retrieve(conf)
successfulPipelines.inc(previouslyScheduledPipelines.length)
updateScheduledPipelines(previouslyScheduledPipelines)

val pipelineModels = pendingPipelineRecords(options.scheduledEnd)
db.run(DBIO.seq(
SchedulerMetrics()
@@ -208,7 +136,8 @@
.waitForResult
metrics.register("gauges.pipeline_count", new ConstantValueGauge(pipelineModels.size))

val localJars = getLocalJars(pipelineModels)
// TODO: this variable can be moved to the implementation
lazy val localJars = getLocalJars(pipelineModels)

// execute all jars
val parPipelineModels = pipelineModels.par
@@ -218,35 +147,12 @@
new ForkJoinPool(parallel * Runtime.getRuntime.availableProcessors)
)

val successfulPipelines = metrics.register("counter.successful-pipeline-deployment-count", new Counter())
val failedPipelines = metrics.register("counter.failed-pipeline-deployment-count", new Counter())

parPipelineModels.foreach { p =>

val timerInst = scheduleTimer.time()

logger.info(s"deploying pipleine ${p.name}")
parPipelineModels.foreach(p => processPipeline(taskDispatcher, p, options, localJars(p.jar), dispatchedPipelines, failedPipelines))

val (status, output, pipelineName) = deployPipeline(
p, options.scheduledStart, options.scheduledEnd, localJars(p.jar))

Console.err.print(output)

if (status == 0) { // deploy successfully, perform activation
activatePipeline(p, pipelineName, options.scheduledStart, options.actualStart, options.scheduledEnd)
logger.info(
s"deployed pipeline ${p.name} in ${TimeUnit.SECONDS.convert(timerInst.stop(), TimeUnit.NANOSECONDS)}"
)
successfulPipelines.inc()
} else { // otherwise handle the failure and send notification
ErrorHandler.pipelineScheduleFailed(p, output)
logger.warn(
s"failed to deploy pipeline ${p.name} in ${TimeUnit.SECONDS.convert(timerInst.stop(), TimeUnit.NANOSECONDS)}"
)
failedPipelines.inc()
}

}
// retrieve the scheduled pipelines and save the information in the db
val scheduledPipelines = taskDispatcher.retrieve(conf)
successfulPipelines.inc(scheduledPipelines.length)
updateScheduledPipelines(scheduledPipelines)

db.run(DBIO.seq(
SchedulerMetrics()
@@ -255,7 +161,6 @@
.update(Option(DateTime.now))
))
.waitForResult

}

/**
@@ -38,6 +38,8 @@ class StarportSettings(val config: Config) extends Serializable {

val jdbc: JdbcConfig = JdbcConfig(config.getConfig("krux.starport.jdbc"))

val dispatcherType: String = if (config.hasPath("krux.starport.dispatcher.type")) config.getString("krux.starport.dispatcher.type") else "default"

val parallel: Int = config.getInt("krux.starport.parallel")

val maxRetry: Int = config.getInt("hyperion.aws.client.max_retry")
@@ -0,0 +1,32 @@
package com.krux.starport.dispatcher

import com.krux.starport.cli.SchedulerOptions
import com.krux.starport.db.record.{Pipeline, ScheduledPipeline}
import com.krux.starport.exception.StarportException
import com.krux.starport.config.StarportSettings

trait TaskDispatcher {

/**
* This method deploys and activates a Pipeline object using the hyperion CLI.
* It should be implemented in a thread-safe manner, as dispatch tasks can be executed in parallel.
* @param pipeline
* @param options
* @param jar
* @param conf
* @return Left with a StarportException on failure, or Right(()) on success; dispatch is invoked for its side effects
*/
def dispatch(pipeline: Pipeline, options: SchedulerOptions, jar: String, conf: StarportSettings): Either[StarportException, Unit]

/**
* This method retrieves all the pipeline ids for the pipelines deployed via dispatch.
* It is meant to be invoked before starting task dispatching or after all the tasks have been dispatched.
* This operation is not meant to be thread-safe and should not be executed in parallel.
* @param conf
* @return Seq of all ScheduledPipelines
*/
def retrieve(conf: StarportSettings): Seq[ScheduledPipeline]

}
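
To make the contract concrete, below is a minimal in-memory sketch of an implementation. It is not the TaskDispatcherImpl shipped in this commit (whose body does not appear in this diff); the class name and the deployAndActivate placeholder are invented for illustration, and retrieve simply drains a ConcurrentLinkedQueue via the ConcurrentQueueHelpers shown further down.

import java.util.concurrent.ConcurrentLinkedQueue

import com.krux.starport.cli.SchedulerOptions
import com.krux.starport.config.StarportSettings
import com.krux.starport.db.record.{Pipeline, ScheduledPipeline}
import com.krux.starport.dispatcher.TaskDispatcher
import com.krux.starport.dispatcher.impl.ConcurrentQueueHelpers._
import com.krux.starport.exception.StarportException

// Illustrative only: dispatch records whatever it "deployed" in a thread-safe
// queue, and retrieve drains that queue afterwards.
class InMemoryDispatcherSketch extends TaskDispatcher {

  private val scheduled = new ConcurrentLinkedQueue[ScheduledPipeline]()

  def dispatch(
    pipeline: Pipeline,
    options: SchedulerOptions,
    jar: String,
    conf: StarportSettings
  ): Either[StarportException, Unit] =
    try {
      deployAndActivate(pipeline, options, jar, conf).foreach(p => scheduled.add(p))
      Right(())
    } catch {
      case e: StarportException => Left(e)
    }

  def retrieve(conf: StarportSettings): Seq[ScheduledPipeline] =
    scheduled.retrieveAll() // drains the queue; single-threaded by contract

  // Placeholder: a real implementation would shell out to the hyperion CLI
  // (or publish to a remote queue) and return the resulting records.
  private def deployAndActivate(
    pipeline: Pipeline,
    options: SchedulerOptions,
    jar: String,
    conf: StarportSettings
  ): Seq[ScheduledPipeline] = Seq.empty
}

In such a sketch, dispatch can be called from the scheduler's parallel collection because ConcurrentLinkedQueue is thread-safe, while retrieve is intended to run single-threaded, matching the trait's documentation.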

@@ -0,0 +1,33 @@
package com.krux.starport.dispatcher.impl

import java.util.concurrent.ConcurrentLinkedQueue
import collection.JavaConverters._
import com.krux.starport.db.record.ScheduledPipeline

/**
* This object contains a few helpers that make Java's ConcurrentLinkedQueue nicer to use from Scala
*/
object ConcurrentQueueHelpers {

implicit class AugmentedConcurrentQueue(q: ConcurrentLinkedQueue[ScheduledPipeline]) {
/**
* Empties the queue
*/
def consumeAll(): Unit = {
while(!q.isEmpty) {
q.poll()
}
}

/**
* Returns all the elements of the queue as an ordered List, then removes them from the queue
* @return the drained elements in insertion order
*/
def retrieveAll() = {
val elements = q.iterator().asScala.toList
consumeAll()
elements
}
}

}
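
A small usage sketch of the implicit class, assuming the queue was populated by earlier dispatch calls (no real ScheduledPipeline records are constructed here):

import java.util.concurrent.ConcurrentLinkedQueue

import com.krux.starport.db.record.ScheduledPipeline
import com.krux.starport.dispatcher.impl.ConcurrentQueueHelpers._

object QueueHelpersUsageSketch {

  // Parallel dispatch threads would append to the queue; a single thread later
  // drains it in one shot. `alreadyScheduled` stands in for whatever the
  // dispatcher produced, so no real records need to be constructed here.
  def drain(alreadyScheduled: Seq[ScheduledPipeline]): Seq[ScheduledPipeline] = {
    val q = new ConcurrentLinkedQueue[ScheduledPipeline]()
    alreadyScheduled.foreach(p => q.add(p))

    val snapshot = q.retrieveAll() // ordered List of everything in the queue
    assert(q.isEmpty)              // retrieveAll also empties the queue
    snapshot
  }
}

Note that retrieveAll snapshots the queue with an iterator and then drains it, so it should only be called once all concurrent producers have finished, which is consistent with the retrieve contract on TaskDispatcher.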
