diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala index 2efd1814bc90..6d36ca51db90 100644 --- a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala +++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala @@ -80,30 +80,27 @@ trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with BeforeAnd System.getProperty("user.dir") } - private def getJarFilePath(root: String): String = { - for (platform <- List("linux-x86_64-cpu", "linux-x86_64-gpu", "osx-x86_64-cpu")) { - val jarFiles = new File(s"$root/$platform/target/").listFiles(new FileFilter { - override def accept(pathname: File) = { - pathname.getAbsolutePath.endsWith(".jar") && - !pathname.getAbsolutePath.contains("javadoc") && - !pathname.getAbsolutePath.contains("sources") - } - }) - if (jarFiles != null && jarFiles.nonEmpty) { - return jarFiles.head.getAbsolutePath + private def findJars(root: String): Array[File] = { + val excludedSuffixes = List("bundle", "src", "javadoc", "sources") + new File(root).listFiles(new FileFilter { + override def accept(pathname: File) = { + pathname.getAbsolutePath.endsWith(".jar") && + excludedSuffixes.forall(!pathname.getAbsolutePath.contains(_)) } + }) + } + + private def getJarFilePath(root: String): String = { + val jarFiles = findJars(s"$root/target/") + if (jarFiles != null && jarFiles.nonEmpty) { + jarFiles.head.getAbsolutePath + } else { + null } - null } private def getSparkJar: String = { - val jarFiles = new File(s"$composeWorkingDirPath/target/").listFiles(new FileFilter { - override def accept(pathname: File) = { - pathname.getAbsolutePath.endsWith(".jar") && - !pathname.getAbsolutePath.contains("javadoc") && - !pathname.getAbsolutePath.contains("sources") - } - }) + val jarFiles = findJars(s"$composeWorkingDirPath/target/") if (jarFiles != null && jarFiles.nonEmpty) { jarFiles.head.getAbsolutePath } else { @@ -111,6 +108,9 @@ trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with BeforeAnd } } + private def getNativeJars(root: String): String = + new File(root).listFiles().map(_.toPath).mkString(",") + protected def buildLeNet(): MXNet = { val workingDir = composeWorkingDirPath val assemblyRoot = s"$workingDir/../assembly" @@ -130,6 +130,8 @@ trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with BeforeAnd protected def buildMlp(): MXNet = { val workingDir = composeWorkingDirPath val assemblyRoot = s"$workingDir/../assembly" + val nativeRoot = s"$workingDir/../native/target/lib" + new MXNet() .setBatchSize(128) .setLabelName("softmax_label") @@ -139,7 +141,7 @@ trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with BeforeAnd .setNumEpoch(10) .setNumServer(1) .setNumWorker(numWorkers) - .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar") + .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar,${getNativeJars(nativeRoot)}") .setJava("java") .setTimeout(0) }