Skip to content

Commit 376de8a

Browse files
ankurdavecloud-fan
authored andcommitted
[SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession
### What changes were proposed in this pull request? `CastSuiteBase` and `ExpressionInfoSuite` use `ParVector.foreach()` to run Spark SQL queries in parallel. They incorrectly assume that each parallel operation will inherit the main thread’s active SparkSession. This is only true when these parallel operations run in freshly-created threads. However, when other code has already run some parallel operations before Spark was started, then there may be existing threads that do not have an active SparkSession. In that case, these tests fail with NullPointerExceptions when creating SparkPlans or running SQL queries. The fix is to use the existing method `ThreadUtils.parmap()`. This method creates fresh threads that inherit the current active SparkSession, and it propagates the Spark ThreadLocals. This PR also adds a scalastyle warning against use of ParVector. ### Why are the changes needed? This change makes `CastSuiteBase` and `ExpressionInfoSuite` less brittle to future changes that may run parallel operations during test startup. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Reproduced the test failures by running a ParVector operation before Spark starts. Verified that this PR fixes the test failures in this condition. ```scala protected override def beforeAll(): Unit = { // Run a ParVector operation before initializing the SparkSession. This starts some Scala // execution context threads that have no active SparkSession. These threads will be reused for // later ParVector operations, reproducing SPARK-45616. new ParVector((0 until 100).toVector).foreach { _ => } super.beforeAll() } ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43466 from ankurdave/SPARK-45616. Authored-by: Ankur Dave <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent e3ba9cf commit 376de8a

File tree

8 files changed

+38
-8
lines changed

8 files changed

+38
-8
lines changed

core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,10 @@ class UnionRDD[T: ClassTag](
7676

7777
override def getPartitions: Array[Partition] = {
7878
val parRDDs = if (isPartitionListingParallel) {
79+
// scalastyle:off parvector
7980
val parArray = new ParVector(rdds.toVector)
8081
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
82+
// scalastyle:on parvector
8183
parArray
8284
} else {
8385
rdds

core/src/main/scala/org/apache/spark/util/ThreadUtils.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,10 @@ private[spark] object ThreadUtils {
363363
* Comparing to the map() method of Scala parallel collections, this method can be interrupted
364364
* at any time. This is useful on canceling of task execution, for example.
365365
*
366+
* Functions are guaranteed to be executed in freshly-created threads that inherit the calling
367+
* thread's Spark thread-local variables. These threads also inherit the calling thread's active
368+
* SparkSession.
369+
*
366370
* @param in - the input collection which should be transformed in parallel.
367371
* @param prefix - the prefix assigned to the underlying thread pool.
368372
* @param maxThreads - maximum number of thread can be created during execution.

scalastyle-config.xml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,18 @@ This file is divided into 3 sections:
227227
]]></customMessage>
228228
</check>
229229

230+
<check customId="parvector" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
231+
<parameters><parameter name="regex">new.*ParVector</parameter></parameters>
232+
<customMessage><![CDATA[
233+
Are you sure you want to create a ParVector? It will not automatically propagate Spark ThreadLocals or the
234+
active SparkSession for the submitted tasks. In most cases, you should use ThreadUtils.parmap instead.
235+
If you must use ParVector, then wrap your creation of the ParVector with
236+
// scalastyle:off parvector
237+
...ParVector...
238+
// scalastyle:on parvector
239+
]]></customMessage>
240+
</check>
241+
230242
<check customId="caselocale" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
231243
<parameters><parameter name="regex">(\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\)))</parameter></parameters>
232244
<customMessage><![CDATA[

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import java.time.{Duration, LocalDate, LocalDateTime, Period}
2222
import java.time.temporal.ChronoUnit
2323
import java.util.{Calendar, Locale, TimeZone}
2424

25-
import scala.collection.parallel.immutable.ParVector
26-
2725
import org.apache.spark.SparkFunSuite
2826
import org.apache.spark.sql.Row
2927
import org.apache.spark.sql.catalyst.InternalRow
@@ -42,6 +40,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND
4240
import org.apache.spark.sql.types.UpCastRule.numericPrecedence
4341
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
4442
import org.apache.spark.unsafe.types.UTF8String
43+
import org.apache.spark.util.ThreadUtils
4544

4645
/**
4746
* Common test suite for [[Cast]] with ansi mode on and off. It only includes test cases that work
@@ -126,7 +125,11 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
126125
}
127126

128127
test("cast string to timestamp") {
129-
new ParVector(ALL_TIMEZONES.toVector).foreach { zid =>
128+
ThreadUtils.parmap(
129+
ALL_TIMEZONES,
130+
prefix = "CastSuiteBase-cast-string-to-timestamp",
131+
maxThreads = Runtime.getRuntime.availableProcessors
132+
) { zid =>
130133
def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = {
131134
checkEvaluation(cast(Literal(str), TimestampType, Option(zid.getId)), expected)
132135
}

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,8 +759,10 @@ case class RepairTableCommand(
759759
val statusPar: Seq[FileStatus] =
760760
if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
761761
// parallelize the list of partitions here, then we can have better parallelism later.
762+
// scalastyle:off parvector
762763
val parArray = new ParVector(statuses.toVector)
763764
parArray.tasksupport = evalTaskSupport
765+
// scalastyle:on parvector
764766
parArray.seq
765767
} else {
766768
statuses

sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,14 @@
1717

1818
package org.apache.spark.sql.expressions
1919

20-
import scala.collection.parallel.immutable.ParVector
21-
2220
import org.apache.spark.SparkFunSuite
2321
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
2422
import org.apache.spark.sql.catalyst.expressions._
2523
import org.apache.spark.sql.execution.HiveResult.hiveResultString
2624
import org.apache.spark.sql.internal.SQLConf
2725
import org.apache.spark.sql.test.SharedSparkSession
2826
import org.apache.spark.tags.SlowSQLTest
29-
import org.apache.spark.util.Utils
27+
import org.apache.spark.util.{ThreadUtils, Utils}
3028

3129
@SlowSQLTest
3230
class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
@@ -201,8 +199,11 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
201199
// The encrypt expression includes a random initialization vector to its encrypted result
202200
classOf[AesEncrypt].getName)
203201

204-
val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
205-
parFuncs.foreach { funcId =>
202+
ThreadUtils.parmap(
203+
spark.sessionState.functionRegistry.listFunction(),
204+
prefix = "ExpressionInfoSuite-check-outputs-of-expression-examples",
205+
maxThreads = Runtime.getRuntime.availableProcessors
206+
) { funcId =>
206207
// Examples can change settings. We clone the session to prevent tests clashing.
207208
val clonedSpark = spark.cloneSession()
208209
// Coalescing partitions can change result order, so disable it.

streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
5252
outputStreams.foreach(_.validateAtStart())
5353
numReceivers = inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]])
5454
inputStreamNameAndID = inputStreams.map(is => (is.name, is.id)).toSeq
55+
// scalastyle:off parvector
5556
new ParVector(inputStreams.toVector).foreach(_.start())
57+
// scalastyle:on parvector
5658
}
5759
}
5860

@@ -62,7 +64,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
6264

6365
def stop(): Unit = {
6466
this.synchronized {
67+
// scalastyle:off parvector
6568
new ParVector(inputStreams.toVector).foreach(_.stop())
69+
// scalastyle:on parvector
6670
}
6771
}
6872

streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,10 @@ private[streaming] object FileBasedWriteAheadLog {
314314
val groupSize = taskSupport.parallelismLevel.max(8)
315315

316316
source.grouped(groupSize).flatMap { group =>
317+
// scalastyle:off parvector
317318
val parallelCollection = new ParVector(group.toVector)
318319
parallelCollection.tasksupport = taskSupport
320+
// scalastyle:on parvector
319321
parallelCollection.map(handler)
320322
}.flatten
321323
}

0 commit comments

Comments
 (0)