Skip to content

Commit

Permalink
work-stealing-test-runner
Browse files Browse the repository at this point in the history
  • Loading branch information
HollandDM committed Feb 23, 2025
1 parent a6992ed commit c09aa01
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 43 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ dist/
build/
*.bak
mill-assembly.jar
mill-native
mill-native
*.mill.orig
*.mill.rej
3 changes: 3 additions & 0 deletions core/api/src/mill/api/Ctx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ object Ctx {
def async[T](dest: os.Path, key: String, message: String)(t: => T)(implicit
ctx: mill.api.Ctx
): Future[T]

def tryDecreaseMaxThreadCount(): Boolean
def tryIncreaseMaxThreadCount(): Boolean
}

trait Impl extends Api with ExecutionContext with AutoCloseable {
Expand Down
21 changes: 21 additions & 0 deletions core/exec/src/mill/exec/ExecutionContexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ private object ExecutionContexts {
ctx: mill.api.Ctx
): Future[T] =
Future.successful(t)

def tryDecreaseMaxThreadCount(): Boolean = false
def tryIncreaseMaxThreadCount(): Boolean = false
}

/**
Expand Down Expand Up @@ -105,5 +108,23 @@ private object ExecutionContexts {
}
}(this)
}

def tryDecreaseMaxThreadCount(): Boolean = synchronized {
if (executor.getActiveCount() < executor.getMaximumPoolSize()) {
updateThreadCount(-1)
true
} else {
false
}
}

def tryIncreaseMaxThreadCount(): Boolean = synchronized {
if (executor.getMaximumPoolSize() < threadCount0) {
updateThreadCount(1)
true
} else {
false
}
}
}
}
41 changes: 38 additions & 3 deletions scalalib/src/mill/scalalib/TestModuleUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package mill.scalalib

import mill.api.{Ctx, PathRef, Result}
import mill.constants.EnvVars
import mill.testrunner.{TestArgs, TestResult, TestRunnerUtils}
import mill.testrunner.{TestArgs, TestResult, TestRunnerUtils, TestMmapCommunicator}
import mill.util.Jvm
import mill.Task
import sbt.testing.Status
Expand All @@ -11,6 +11,7 @@ import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit
import java.time.{Instant, LocalDateTime, ZoneId}
import scala.xml.Elem
import scala.annotation.switch

private[scalalib] object TestModuleUtil {
def runTests(
Expand Down Expand Up @@ -59,16 +60,18 @@ private[scalalib] object TestModuleUtil {

val argsFile = base / "testargs"
val sandbox = base / "sandbox"
val communacateFile = base / "comm.dat"
os.write(argsFile, upickle.default.write(testArgs), createFolders = true)
os.write(communacateFile, Array.fill[Byte](TestMmapCommunicator.BufferSize)(0.toByte), createFolders = false)

os.makeDir.all(sandbox)

Jvm.callProcess(
val subprocess = Jvm.spawnProcess(
mainClass = "mill.testrunner.entrypoint.TestRunnerMain",
classPath = (runClasspath ++ testrunnerEntrypointClasspath).map(_.path),
jvmArgs = jvmArgs,
env = forkEnv ++ resourceEnv,
mainArgs = Seq(testRunnerClasspathArg, argsFile.toString),
mainArgs = Seq(testRunnerClasspathArg, argsFile.toString, communacateFile.toString),
cwd = if (testSandboxWorkingDir) sandbox else forkWorkingDir,
cpPassingJarPath = Option.when(useArgsFile)(
os.temp(prefix = "run-", suffix = ".jar", deleteOnExit = false)
Expand All @@ -78,6 +81,38 @@ private[scalalib] object TestModuleUtil {
stdout = os.Inherit
)

TestMmapCommunicator.using(communacateFile.toString) { communicator =>
val sleepMillis = 100 + scala.util.Random.nextInt(100)
var delegatedThreadCount = 0

while (subprocess.isAlive()) {
(communicator.readSignal(): @switch) match {
case 0 => ()
case 1 =>
if (Task.fork.tryDecreaseMaxThreadCount()) {
delegatedThreadCount += 1
communicator.writeSignal(2)
}
case 2 => ()
case other =>
communicator.writeSignal(0)
}
communicator.readAll().zipWithIndex.foreach { case (i, index) =>
if (i == 1) {
Task.fork.tryIncreaseMaxThreadCount()
delegatedThreadCount -= 1
communicator.writeIndex(index, 2)
}
}
Thread.sleep(sleepMillis)
}

while (delegatedThreadCount > 0) {
Task.fork.tryIncreaseMaxThreadCount()
delegatedThreadCount -= 1
}
}

if (!os.exists(outputPath))
Result.Failure(s"Test reporting Failed: ${outputPath} does not exist")
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ object TestRunnerScalatestTests extends TestSuite {
3,
Map(
// No test grouping is triggered because we only run one test class
testrunner.scalatest -> Set("out.json", "sandbox", "test-report.xml", "testargs"),
testrunnerGrouping.scalatest -> Set("out.json", "sandbox", "test-report.xml", "testargs")
testrunner.scalatest -> Set("out.json", "sandbox", "test-report.xml", "testargs", "comm.dat"),
testrunnerGrouping.scalatest -> Set("out.json", "sandbox", "test-report.xml", "testargs", "comm.dat")
)
)

Expand All @@ -42,7 +42,7 @@ object TestRunnerScalatestTests extends TestSuite {
Seq("*"),
9,
Map(
testrunner.scalatest -> Set("out.json", "sandbox", "test-report.xml", "testargs"),
testrunner.scalatest -> Set("out.json", "sandbox", "test-report.xml", "testargs", "comm.dat"),
testrunnerGrouping.scalatest -> Set(
"group-0-mill.scalalib.ScalaTestSpec",
"mill.scalalib.ScalaTestSpec3",
Expand Down Expand Up @@ -74,7 +74,7 @@ object TestRunnerScalatestTests extends TestSuite {
Seq("*", "--", "-z", "A Set 2"),
3,
Map(
testrunner.scalatest -> Set("out.json", "sandbox", "test-report.xml", "testargs"),
testrunner.scalatest -> Set("out.json", "sandbox", "test-report.xml", "testargs", "comm.dat"),
testrunnerGrouping.scalatest -> Set(
"group-0-mill.scalalib.ScalaTestSpec",
"mill.scalalib.ScalaTestSpec3",
Expand Down
8 changes: 4 additions & 4 deletions scalalib/test/src/mill/scalalib/TestRunnerUtestTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ object TestRunnerUtestTests extends TestSuite {
Seq("mill.scalalib.FooTests"),
1,
Map(
testrunner.utest -> Set("out.json", "sandbox", "test-report.xml", "testargs"),
testrunner.utest -> Set("out.json", "sandbox", "test-report.xml", "testargs", "comm.dat"),
// When there is only one test group with test classes, we do not put it in a subfolder
testrunnerGrouping.utest -> Set("out.json", "sandbox", "test-report.xml", "testargs")
testrunnerGrouping.utest -> Set("out.json", "sandbox", "test-report.xml", "testargs", "comm.dat")
)
)
test("multi") - tester.testOnly(
Seq("*Bar*", "*bar*"),
2,
Map(
testrunner.utest -> Set("out.json", "sandbox", "test-report.xml", "testargs"),
testrunner.utest -> Set("out.json", "sandbox", "test-report.xml", "testargs", "comm.dat"),
// When there are multiple test groups with one test class each, we
// put each test group in a subfolder with the number of the class
testrunnerGrouping.utest -> Set(
Expand All @@ -57,7 +57,7 @@ object TestRunnerUtestTests extends TestSuite {
Seq("*"),
3,
Map(
testrunner.utest -> Set("out.json", "sandbox", "test-report.xml", "testargs"),
testrunner.utest -> Set("out.json", "sandbox", "test-report.xml", "testargs", "comm.dat"),
// When there are multiple test groups some with multiple test classes, we put each
// test group in a subfolder with the index of the group, and for any test groups
// with only one test class we append the name of the class
Expand Down
82 changes: 82 additions & 0 deletions testrunner/src/mill/testrunner/TestMmapCommunicator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package mill.testrunner

import java.nio.channels.FileChannel
import java.io.RandomAccessFile
import scala.util.Try
import scala.util.Using
import java.nio.ByteBuffer

private[mill] sealed abstract class TestMmapCommunicator extends AutoCloseable {

def readSignal(): Int

def writeSignal(signal: Int): Unit

def readIndex(index: Int): Int

def writeIndex(index: Int, value: Int): Unit

def readAll(): Array[Int]

def close(): Unit

}

private[mill] object TestMmapCommunicator {
/**
* first 4 bytes: communicate signal
* next 4092 bytes: thread specific communication slots (4 bytes each)
*
* This mean that we can give back thread at most 1023 times to the parent process.
* Other threads will be given back when test process terminates.
*/
final val BufferSize = 4096

private[testrunner] val emptyCommunicator = new TestMmapCommunicator {
override def readSignal(): Int = -1
override def writeSignal(signal: Int): Unit = ()
override def readIndex(index: Int): Int = -1
override def writeIndex(index: Int, value: Int): Unit = ()
override def readAll(): Array[Int] = Array.empty
override def close(): Unit = ()
}

def using[A](filename: String)(body: TestMmapCommunicator => A): A = {
val communicator = Try {
val file = new RandomAccessFile(filename, "rw")
val channel = file.getChannel
val buffer = channel.map(FileChannel.MapMode.READ_WRITE, 0, BufferSize)

new TestMmapCommunicator {
override def readSignal(): Int = this.synchronized { buffer.getInt(0) }
override def writeSignal(signal: Int): Unit = this.synchronized { buffer.putInt(0, signal) }
override def readIndex(index: Int): Int = if (index < 0 || index >= 1023) {
-1
} else {
this.synchronized { buffer.getInt(4 + index * 4) }
}
override def writeIndex(index: Int, value: Int): Unit = if (index < 0 || index >= 1023) {
()
} else {
this.synchronized { buffer.putInt(4 + index * 4, value) }
}
override def readAll(): Array[Int] = {
val length = (BufferSize >> 2) - 1
val array = new Array[Int](length)
this.synchronized {
for (i <- 0 until length) {
array(i) = buffer.getInt(4 + i * 4)
}
}
array
}
override def close(): Unit = {
channel.close()
file.close()
}
}
}.getOrElse(emptyCommunicator)

Using.resource(communicator)(body)
}
}
20 changes: 12 additions & 8 deletions testrunner/src/mill/testrunner/TestRunnerMain0.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import mill.internal.PrintLogger
def main0(args: Array[String], classLoader: ClassLoader): Unit = {
try {
val testArgs = upickle.default.read[mill.testrunner.TestArgs](os.read(os.Path(args(1))))
val communicatorFile = args.lift(2).getOrElse("")
val ctx = new Ctx.Log with Ctx.Home {
val log = new PrintLogger(
testArgs.colored,
Expand All @@ -27,14 +28,17 @@ import mill.internal.PrintLogger

val filter = TestRunnerUtils.globFilter(testArgs.globSelectors)

val result = TestRunnerUtils.runTestFramework0(
frameworkInstances = Framework.framework(testArgs.framework),
testClassfilePath = Seq.from(testArgs.testCp),
args = testArgs.arguments,
classFilter = cls => filter(cls.getName),
cl = classLoader,
testReporter = DummyTestReporter
)(ctx)
val result = TestMmapCommunicator.using(communicatorFile) { communicator =>
TestRunnerUtils.runTestFramework0(
frameworkInstances = Framework.framework(testArgs.framework),
testClassfilePath = Seq.from(testArgs.testCp),
args = testArgs.arguments,
classFilter = cls => filter(cls.getName),
cl = classLoader,
testReporter = DummyTestReporter,
communicator = communicator
)(ctx)
}

// Clear interrupted state in case some badly-behaved test suite
// dirtied the thread-interrupted flag and forgot to clean up. Otherwise,
Expand Down
Loading

0 comments on commit c09aa01

Please sign in to comment.