
Commit bb655de

Shufflevault: Shuffle on S3
1 parent 4eb56bb commit bb655de

36 files changed, +948 −46 lines

core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 9 additions & 3 deletions
@@ -98,6 +98,7 @@ final class BypassMergeSortShuffleWriter<K, V>
   private final long mapId;
   private final Serializer serializer;
   private final ShuffleExecutorComponents shuffleExecutorComponents;
+  private final boolean remoteWrites;
 
   /** Array of file writers, one for each partition */
   private DiskBlockObjectWriter[] partitionWriters;
@@ -136,6 +137,7 @@ final class BypassMergeSortShuffleWriter<K, V>
     this.mapId = mapId;
     this.shuffleId = dep.shuffleId();
     this.partitioner = dep.partitioner();
+    this.remoteWrites = dep.useRemoteShuffleStorage();
     this.numPartitions = partitioner.numPartitions();
     this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
@@ -149,12 +151,14 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
     ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents
       .createMapOutputWriter(shuffleId, mapId, numPartitions);
+    BlockManagerId blockManagerId = remoteWrites ?
+      RemoteShuffleStorage.BLOCK_MANAGER_ID() : blockManager.shuffleServerId();
     try {
       if (!records.hasNext()) {
         partitionLengths = mapOutputWriter.commitAllPartitions(
           ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
         mapStatus = MapStatus$.MODULE$.apply(
-          blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
+          blockManagerId, partitionLengths, mapId, getAggregatedChecksumValue());
         return;
       }
       final SerializerInstance serInstance = serializer.newInstance();
@@ -196,7 +200,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
 
       partitionLengths = writePartitionedData(mapOutputWriter);
       mapStatus = MapStatus$.MODULE$.apply(
-        blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
+        blockManagerId, partitionLengths, mapId, getAggregatedChecksumValue());
     } catch (Exception e) {
       try {
         mapOutputWriter.abort(e);
@@ -236,8 +240,10 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
     try {
       for (int i = 0; i < numPartitions; i++) {
         final File file = partitionWriterSegments[i].file();
-        ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
         if (file.exists()) {
+          // TODO: Remove this comment: the line below was moved so that assertions
+          // can be added and it is safer in general
+          ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
           if (transferToEnabled) {
             // Using WritableByteChannelWrapper to make resource closing consistent between
             // this implementation and UnsafeShuffleWriter.
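
Both writers follow the same pattern: when remoteWrites is set, the MapStatus advertises RemoteShuffleStorage.BLOCK_MANAGER_ID() instead of the local shuffle server, so reducers can tell that the map output lives in remote storage rather than on an executor. The RemoteShuffleStorage object itself is not shown in this hunk; a minimal sketch of what the sentinel id might look like, with the executor id, host, and port chosen purely for illustration (modeled on how FallbackStorage advertises a synthetic id):

package org.apache.spark.storage

// Hypothetical sketch, not the commit's implementation: a synthetic BlockManagerId
// that no executor actually serves. Readers that see this id in a MapStatus would
// resolve the block against spark.shuffle.remote.storage.path instead of fetching
// it from a shuffle server.
private[spark] object RemoteShuffleStorage {
  val BLOCK_MANAGER_ID: BlockManagerId = BlockManagerId("remote-shuffle", "remote", 7337)
}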

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 7 additions & 1 deletion
@@ -26,6 +26,8 @@
 import java.nio.channels.WritableByteChannel;
 import java.util.Iterator;
 
+import org.apache.spark.storage.BlockManagerId;
+import org.apache.spark.storage.RemoteShuffleStorage;
 import scala.Option;
 import scala.Product2;
 import scala.jdk.javaapi.CollectionConverters;
@@ -89,6 +91,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final boolean transferToEnabled;
   private final int initialSortBufferSize;
   private final int mergeBufferSizeInBytes;
+  private final boolean remoteWrites;
 
   @Nullable private MapStatus mapStatus;
   @Nullable private ShuffleExternalSorter sorter;
@@ -135,6 +138,7 @@ public UnsafeShuffleWriter(
     this.shuffleId = dep.shuffleId();
     this.serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
+    this.remoteWrites = dep.useRemoteShuffleStorage();
     this.writeMetrics = writeMetrics;
     this.shuffleExecutorComponents = shuffleExecutorComponents;
     this.taskContext = taskContext;
@@ -247,8 +251,10 @@ void closeAndWriteOutput() throws IOException {
         }
       }
     }
+    BlockManagerId blockManagerId = remoteWrites ?
+      RemoteShuffleStorage.BLOCK_MANAGER_ID() : blockManager.shuffleServerId();
     mapStatus = MapStatus$.MODULE$.apply(
-      blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
+      blockManagerId, partitionLengths, mapId, getAggregatedChecksumValue());
   }
 
   @VisibleForTesting
core/src/main/java/org/apache/spark/storage/FileSystemManagedBuffer.java

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * A {@link ManagedBuffer} backed by a file on a Hadoop {@link FileSystem}.
+ * This buffer creates an input stream with a configurable buffer size (64MB by default)
+ * for efficient reading.
+ * Note: This implementation throws UnsupportedOperationException for methods that
+ * require loading the entire file into memory (nioByteBuffer, convertToNetty,
+ * convertToNettyForSsl), as files can be very large and loading them entirely into
+ * memory is not practical.
+ */
+public class FileSystemManagedBuffer extends ManagedBuffer {
+  private int bufferSize; // read buffer size in MB (defaults to 64)
+  private final Path filePath;
+  private final long fileSize;
+  private final Configuration hadoopConf;
+
+  public FileSystemManagedBuffer(Path filePath, Configuration hadoopConf) throws IOException {
+    this.filePath = filePath;
+    this.hadoopConf = hadoopConf;
+    // Get file size using FileSystem.newInstance to avoid cached dependencies
+    FileSystem fileSystem = FileSystem.newInstance(filePath.toUri(), hadoopConf);
+    try {
+      this.fileSize = fileSystem.getFileStatus(filePath).getLen();
+    } finally {
+      fileSystem.close();
+    }
+    bufferSize = 64;
+  }
+
+  public FileSystemManagedBuffer(Path filePath, Configuration hadoopConf, int bufferSize)
+      throws IOException {
+    this(filePath, hadoopConf);
+    this.bufferSize = bufferSize;
+  }
+
+  @Override
+  public long size() {
+    return fileSize;
+  }
+
+  @Override
+  public ByteBuffer nioByteBuffer() throws IOException {
+    throw new UnsupportedOperationException(
+      "FileSystemManagedBuffer does not support nioByteBuffer() as it would require loading " +
+      "the entire file into memory, which is not practical for large files. " +
+      "Use createInputStream() instead.");
+  }
+
+  @Override
+  public InputStream createInputStream() throws IOException {
+    // Create a new FileSystem instance to avoid cached dependencies and open the file
+    // with the configured buffer size (in MB) for efficient reading
+    FileSystem fileSystem = FileSystem.newInstance(filePath.toUri(), hadoopConf);
+    return fileSystem.open(filePath, bufferSize * 1024 * 1024);
+  }
+
+  @Override
+  public ManagedBuffer retain() {
+    // FileSystemManagedBuffer doesn't use reference counting, so just return this
+    return this;
+  }
+
+  @Override
+  public ManagedBuffer release() {
+    // FileSystemManagedBuffer doesn't use reference counting, so just return this
+    return this;
+  }
+
+  @Override
+  public Object convertToNetty() {
+    throw new UnsupportedOperationException(
+      "FileSystemManagedBuffer does not support convertToNetty() as it would require loading " +
+      "the entire file into memory, which is not practical for large files. " +
+      "Use createInputStream() instead.");
+  }
+
+  @Override
+  public Object convertToNettyForSsl() {
+    throw new UnsupportedOperationException(
+      "FileSystemManagedBuffer does not support convertToNettyForSsl() as it would require " +
+      "loading the entire file into memory, which is not practical for large files. " +
+      "Use createInputStream() instead.");
+  }
+
+  @Override
+  public String toString() {
+    return "FileSystemManagedBuffer[file=" + filePath + ",length=" + fileSize + "]";
+  }
+}
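
A short usage sketch for the new buffer from Scala; the bucket and file name below are made up for illustration, and any Hadoop-compatible URI (s3a://, hdfs://, file://) should work the same way:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.storage.FileSystemManagedBuffer

val hadoopConf = new Configuration()
// Hypothetical shuffle data file previously written to remote storage.
val path = new Path("s3a://my-bucket/spark-shuffle/shuffle_0_1_0.data")

val buffer = new FileSystemManagedBuffer(path, hadoopConf) // 64MB read buffer by default
val in = buffer.createInputStream()
try {
  println(s"block size: ${buffer.size()} bytes, first byte: ${in.read()}")
  // nioByteBuffer() and convertToNetty() deliberately throw; always stream instead.
} finally {
  in.close()
}

The three-argument constructor overrides the read buffer size, expressed in MB.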

core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 6 additions & 2 deletions
@@ -90,7 +90,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     val mapSideCombine: Boolean = false,
     val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
     val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS,
-    val checksumMismatchFullRetryEnabled: Boolean = false)
+    val checksumMismatchFullRetryEnabled: Boolean = false,
+    val useRemoteShuffleStorage: Boolean = false
+  )
   extends Dependency[Product2[K, V]] with Logging {
 
   def this(
@@ -249,7 +251,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     )
   }
 
-  _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+  if (!useRemoteShuffleStorage) {
+    _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+  }
   _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
 }

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 7 additions & 0 deletions
@@ -650,6 +650,8 @@ class SparkContext(config: SparkConf) extends Logging {
     _env.blockManager.initialize(_applicationId)
     FallbackStorage.registerBlockManagerIfNeeded(
       _env.blockManager.master, _conf, _hadoopConfiguration)
+    RemoteShuffleStorage.registerBlockManagerifNeeded(_env.blockManager.master, _conf,
+      _hadoopConfiguration)
 
     // The metrics system for Driver need to be set spark.app.id to app ID.
     // So it should start after we get app ID from the task scheduler and set spark.app.id.
@@ -2377,6 +2379,11 @@ class SparkContext(config: SparkConf) extends Logging {
     Utils.tryLogNonFatalError {
       FallbackStorage.cleanUp(_conf, _hadoopConfiguration)
     }
+
+    Utils.tryLogNonFatalError {
+      RemoteShuffleStorage.cleanUp(_conf, _hadoopConfiguration)
+    }
+
     Utils.tryLogNonFatalError {
       _eventLogger.foreach(_.stop())
     }

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 35 additions & 0 deletions
@@ -626,6 +626,13 @@ package object config {
       .checkValue(_.endsWith(java.io.File.separator), "Path should end with separator.")
       .createOptional
 
+  private[spark] val SHUFFLE_REMOTE_STORAGE_CLEANUP =
+    ConfigBuilder("spark.shuffle.remote.storage.cleanUp")
+      .doc("If true, Spark cleans up its remote shuffle storage data during shutdown.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val STORAGE_DECOMMISSION_SHUFFLE_MAX_DISK_SIZE =
     ConfigBuilder("spark.storage.decommission.shuffleBlocks.maxDiskSize")
       .doc("Maximum disk space to use to store shuffle blocks before rejecting remote " +
@@ -2905,4 +2912,32 @@ package object config {
       .checkValue(v => v.forall(Set("stdout", "stderr").contains),
         "The value only can be one or more of 'stdout, stderr'.")
       .createWithDefault(Seq("stdout", "stderr"))
+
+  private[spark] val SHUFFLE_REMOTE_STORAGE_PATH =
+    ConfigBuilder("spark.shuffle.remote.storage.path")
+      .doc("The location for storing shuffle blocks on remote storage.")
+      .version("4.1.0")
+      .stringConf
+      .checkValue(_.endsWith(java.io.File.separator), "Path should end with separator.")
+      .createOptional
+
+  private[spark] val REMOTE_SHUFFLE_BUFFER_SIZE =
+    ConfigBuilder("spark.shuffle.remote.buffer.size")
+      .version("4.1.0")
+      .stringConf
+      .createWithDefault("64M")
+
+  private[spark] val START_REDUCERS_IN_PARALLEL_TO_MAPPER =
+    ConfigBuilder("spark.shuffle.consolidation.enabled")
+      .doc("If true, starts reducers in parallel with mappers.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[spark] val EAGERNESS_THRESHOLD_PERCENTAGE =
+    ConfigBuilder("spark.shuffle.remote.eagerness.percentage")
+      .doc("Percentage of map tasks that must complete before reducers are started.")
+      .version("4.1.0")
+      .intConf
+      .createWithDefault(20)
 }
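
Taken together, a job would opt into remote shuffle storage with settings along these lines (values are illustrative; the storage path must end with a separator to pass the checkValue above):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.remote.storage.path", "s3a://my-bucket/spark-shuffle/") // trailing separator required
  .set("spark.shuffle.remote.storage.cleanUp", "true")     // delete remote shuffle data on shutdown
  .set("spark.shuffle.remote.buffer.size", "64M")          // read buffer for remote shuffle files
  .set("spark.shuffle.consolidation.enabled", "true")      // start reducers in parallel with mappers
  .set("spark.shuffle.remote.eagerness.percentage", "20")  // % of map tasks done before reducers start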

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 6 additions & 1 deletion
@@ -1755,7 +1755,12 @@ private[spark] class DAGScheduler(
         log"${MDC(STAGE, stage)} (${MDC(RDD_ID, stage.rdd)}) (first 15 tasks are " +
         log"for partitions ${MDC(PARTITION_IDS, tasks.take(15).map(_.partitionId))})")
     val shuffleId = stage match {
-      case s: ShuffleMapStage => Some(s.shuffleDep.shuffleId)
+      case s: ShuffleMapStage =>
+        // hack to prioritize remote shuffle writes
+        if (properties != null) {
+          properties.setProperty("remote", s.shuffleDep.useRemoteShuffleStorage.toString)
+        }
+        Some(s.shuffleDep.shuffleId)
       case _: ResultStage => None
     }

core/src/main/scala/org/apache/spark/scheduler/Pool.scala

Lines changed: 4 additions & 1 deletion
@@ -53,8 +53,11 @@ private[spark] class Pool(
         new FairSchedulingAlgorithm()
       case SchedulingMode.FIFO =>
         new FIFOSchedulingAlgorithm()
+      case SchedulingMode.WEIGHTED_FIFO =>
+        new WeightedFIFOSchedulingAlgorithm()
       case _ =>
-        val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead."
+        val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR, FIFO," +
+          s" or WEIGHTED_FIFO instead."
         throw new IllegalArgumentException(msg)
     }
   }

core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala

Lines changed: 19 additions & 0 deletions
@@ -40,6 +40,25 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm {
   }
 }
 
+private[spark] class WeightedFIFOSchedulingAlgorithm extends SchedulingAlgorithm {
+  override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
+    val priority1 = s1.priority
+    val priority2 = s2.priority
+    var res = math.signum(priority1 - priority2)
+    if (res == 0) {
+      if (s1.weight == s2.weight) {
+        val stageId1 = s1.stageId
+        val stageId2 = s2.stageId
+        res = math.signum(stageId1 - stageId2)
+      } else {
+        // The higher the weight, the earlier it should run (unlike priority)
+        res = math.signum(s2.weight - s1.weight)
+      }
+    }
+    res < 0
+  }
+}
+
 private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
   override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
     val minShare1 = s1.minShare
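
WEIGHTED_FIFO orders by priority first, exactly like FIFO, and only uses weight to break ties, with the higher weight scheduled earlier and stage id breaking any remaining tie. A standalone illustration of that ordering (a toy Entry class instead of Spark's Schedulable; the mode itself would presumably be selected with spark.scheduler.mode=WEIGHTED_FIFO):

// Toy stand-in for Schedulable, for illustration only.
case class Entry(priority: Int, weight: Int, stageId: Int)

def comesFirst(s1: Entry, s2: Entry): Boolean = {
  var res = math.signum(s1.priority - s2.priority)
  if (res == 0) {
    res =
      if (s1.weight == s2.weight) math.signum(s1.stageId - s2.stageId)
      else math.signum(s2.weight - s1.weight) // higher weight runs earlier
  }
  res < 0
}

// Same priority: the heavier entry is scheduled first.
assert(comesFirst(Entry(priority = 1, weight = 5, stageId = 7),
  Entry(priority = 1, weight = 1, stageId = 3)))
// Different priorities: the lower priority value still wins regardless of weight.
assert(comesFirst(Entry(priority = 0, weight = 1, stageId = 9),
  Entry(priority = 2, weight = 100, stageId = 1)))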

core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala

Lines changed: 2 additions & 1 deletion
@@ -20,10 +20,11 @@ package org.apache.spark.scheduler
 /**
  * "FAIR" and "FIFO" determines which policy is used
  * to order tasks amongst a Schedulable's sub-queues
+ * "WEIGHTED_FIFO" is similar to FIFO but additionally compares weights when priorities are equal.
  * "NONE" is used when the a Schedulable has no sub-queues.
  */
 object SchedulingMode extends Enumeration {
 
   type SchedulingMode = Value
-  val FAIR, FIFO, NONE = Value
+  val FAIR, FIFO, WEIGHTED_FIFO, NONE = Value
 }
