
Commit 913d48d

make sure there is at most one spark context inside the same jvm
1 parent aa43a8d commit 913d48d
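
Context for the change: Spark does not support more than one active SparkContext in a single JVM. Because LocalSparkContext previously created sc as a constructor val, every suite mixing in the trait spun up a context as soon as the suite object was instantiated, so contexts from different suites could overlap. The diffs below move creation into ScalaTest's beforeAll() and a guarded stop into afterAll(), bounding each context's lifetime to a single suite run.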

3 files changed: +38 −12 lines

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 15 additions & 3 deletions
@@ -21,15 +21,23 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.{SQLContext, SchemaRDD}
 
 class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
 
-  import sqlContext._
+  @transient var sqlContext: SQLContext = _
+  @transient var dataset: SchemaRDD = _
 
-  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+    dataset = sqlContext.createSchemaRDD(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+  }
 
   test("logistic regression") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression
     val model = lr.fit(dataset)
     model.transform(dataset)
@@ -38,6 +46,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   test("logistic regression with setters") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression()
       .setMaxIter(10)
       .setRegParam(1.0)
@@ -48,6 +58,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   test("logistic regression fit and transform with varargs") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression
     val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
     model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
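
One detail worth noting in this diff: sqlContext is now a var assigned in beforeAll(), and Scala's `import x._` only accepts a stable identifier (a val or object), not a var. That is why each test first rebinds `val sqlContext = this.sqlContext` before `import sqlContext._`. A minimal, self-contained sketch of the same pattern (the Holder class and suite name are hypothetical, not part of this commit):

```scala
import org.scalatest.{BeforeAndAfterAll, FunSuite}

// Hypothetical helper, only for illustration.
class Holder { val greeting = "hello" }

class StableImportExample extends FunSuite with BeforeAndAfterAll {
  @transient var holder: Holder = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    holder = new Holder  // assigned late, so it cannot be a val
  }

  test("import members via a stable identifier") {
    // import holder._     // would not compile: "stable identifier required"
    val h = holder         // rebind the var to a local val ...
    import h._             // ... which is a stable identifier
    assert(greeting === "hello")
  }
}
```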

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 8 additions & 3 deletions
@@ -23,13 +23,18 @@ import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.{SQLContext, SchemaRDD}
 
 class CrossValidatorSuite extends FunSuite with LocalSparkContext {
 
-  import sqlContext._
+  @transient var dataset: SchemaRDD = _
 
-  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val sqlContext = new SQLContext(sc)
+    dataset = sqlContext.createSchemaRDD(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+  }
 
   test("cross validation with logistic regression") {
     val lr = new LogisticRegression
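
Unlike LogisticRegressionSuite above, this suite only needs the SQLContext while building the dataset, so it is kept as a local val inside beforeAll() rather than as a field; the test body never imports the SQL implicits.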

mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala

Lines changed: 15 additions & 6 deletions
@@ -17,17 +17,26 @@
 
 package org.apache.spark.mllib.util
 
-import org.scalatest.{BeforeAndAfterAll, Suite}
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
 
 trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
-  @transient val sc = new SparkContext("local", "test")
-  @transient lazy val sqlContext = new SQLContext(sc)
+  @transient var sc: SparkContext = _
+
+  override def beforeAll() {
+    super.beforeAll()
+    val conf = new SparkConf()
+      .setMaster("local[2]")
+      .setAppName("MLlibUnitTest")
+    sc = new SparkContext(conf)
+  }
 
   override def afterAll() {
-    sc.stop()
+    if (sc != null) {
+      sc.stop()
+    }
     super.afterAll()
   }
 }
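
With the trait reworked this way, a suite gets exactly one live context per run, and sc is only valid between beforeAll() and afterAll(), so it must not be touched in field initializers. A hypothetical suite using the trait (ExampleSuite is not part of this commit):

```scala
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext

class ExampleSuite extends FunSuite with LocalSparkContext {
  test("the shared context is live inside a test") {
    // sc was created by LocalSparkContext.beforeAll() on local[2].
    assert(sc.parallelize(1 to 100, 2).count() === 100)
  }
}
```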

0 commit comments