90 changes: 90 additions & 0 deletions client-spark/common/pom.xml
@@ -89,6 +89,96 @@
<groupId>net.jpountz.lz4</groupId>
<artifactId>lz4</artifactId>
</dependency>

<dependency>
<groupId>org.apache.fory</groupId>
<artifactId>fory-core</artifactId>
<version>0.12.0</version>

Review comment: Please also introduce the fory-scala dependency: https://mvnrepository.com/artifact/org.apache.fory/fory-scala

Member Author: Pity to say that Spark still uses Scala 2.x.

</dependency>

<!-- Scala dependencies -->
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
<scope>provided</scope>
</dependency>

<!-- Test dependencies -->
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.2.15</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatestplus</groupId>
<artifactId>junit-4-13_${scala.binary.version}</artifactId>
<version>3.2.15.0</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>${scala.maven.plugin.version}</version>
<executions>
<execution>
<id>scala-compile-first</id>
<phase>process-resources</phase>
<goals>
<goal>add-source</goal>
<goal>compile</goal>
</goals>
</execution>
<execution>
<id>scala-test-compile-first</id>
<phase>process-test-resources</phase>
<goals>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
</configuration>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<executions>
<execution>
<phase>compile</phase>
<goals>
<goal>compile</goal>
</goals>
</execution>
</executions>
</plugin>

<!-- ScalaTest Maven plugin for running tests -->
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<version>2.0.2</version>
<configuration>
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
<junitxml>.</junitxml>
<filereports>TestSuite.txt</filereports>
</configuration>
<executions>
<execution>
<id>test</id>
<goals>
<goal>test</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
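
The scalatest and junit-4-13 test dependencies above, together with the scalatest-maven-plugin, wire Scala suites into the Maven test phase. For illustration, a minimal round-trip suite for the serializer added later in this PR could look like the following sketch (the suite name and fixture type are hypothetical):

import org.apache.spark.serializer.ForySerializerInstance
import org.scalatest.funsuite.AnyFunSuite

// Hypothetical fixture type; any plain serializable class works.
case class Record(id: Int, name: String)

class ForySerializerSuite extends AnyFunSuite {
  test("serialize/deserialize round-trip") {
    val instance = new ForySerializerInstance()
    val buf = instance.serialize(Record(1, "uniffle"))
    assert(instance.deserialize[Record](buf) == Record(1, "uniffle"))
  }
}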
RssSparkConfig.java
@@ -36,6 +36,7 @@
import org.apache.uniffle.common.config.ConfigUtils;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.shuffle.ShuffleSerializer;

public class RssSparkConfig {

@@ -110,6 +111,12 @@ public class RssSparkConfig {
.defaultValue(true)
.withDescription("indicates row based shuffle, set false when use in columnar shuffle");

public static final ConfigOption<ShuffleSerializer> RSS_SHUFFLE_SERIALIZER =
ConfigOptions.key("rss.client.shuffle.serializer")
.enumType(ShuffleSerializer.class)
.noDefaultValue()
.withDescription("Shuffle serializer type");

public static final ConfigOption<Boolean> RSS_MEMORY_SPILL_ENABLED =
ConfigOptions.key("rss.client.memory.spill.enabled")
.booleanType()
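
For reference, enabling the new serializer from a Spark job would look roughly like this sketch (it assumes Uniffle's usual convention of exposing rss.client.* options under the spark. prefix; the keys and values shown are illustrative):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Route shuffle through Uniffle, then opt in to the Fory serializer.
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager")
  .set("spark.rss.client.shuffle.serializer", "FORY")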
WriteBufferManager.java
@@ -40,9 +40,9 @@
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.ForySerializerInstance;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.RssSparkConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -58,6 +58,7 @@
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;

import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED;

public class WriteBufferManager extends MemoryConsumer {
@@ -81,7 +82,6 @@ public class WriteBufferManager extends MemoryConsumer {
private int shuffleId;
private String taskId;
private long taskAttemptId;
private SerializerInstance instance;
private ShuffleWriteMetrics shuffleWriteMetrics;
// cache partition -> records
private Map<Integer, WriterBuffer> buffers;
@@ -192,8 +192,11 @@ public WriteBufferManager(
// in columnar shuffle, the serializer here is never used
this.isRowBased = rssConf.getBoolean(RssSparkConfig.RSS_ROW_BASED);
if (isRowBased) {
this.instance = serializer.newInstance();
this.serializeStream = instance.serializeStream(arrayOutputStream);
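// Prefer the Fory-backed stream when rss.client.shuffle.serializer is set;
// otherwise fall back to the serializer Spark was configured with.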
if (rssConf.contains(RSS_SHUFFLE_SERIALIZER)) {
this.serializeStream = new ForySerializerInstance().serializeStream(arrayOutputStream);
} else {
this.serializeStream = serializer.newInstance().serializeStream(arrayOutputStream);
}
}
boolean compress =
rssConf.getBoolean(
ShuffleSerializer.java (new file)
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.uniffle.shuffle;

public enum ShuffleSerializer {
FORY
}
ForySerializer.scala (new file)
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer

import org.apache.fory.config.{CompatibleMode, Language}
import org.apache.fory.io.ForyInputStream
import org.apache.fory.{Fory, ThreadLocalFory}
import org.apache.spark.internal.Logging

import java.io.{InputStream, OutputStream, Serializable}
import java.nio.ByteBuffer
import scala.reflect.ClassTag

@SerialVersionUID(1L)
class ForySerializer extends org.apache.spark.serializer.Serializer
with Logging
with Serializable {

override def newInstance(): SerializerInstance = new ForySerializerInstance()

override def supportsRelocationOfSerializedObjects: Boolean = true

}

class ForySerializerInstance extends org.apache.spark.serializer.SerializerInstance {

private val fury = Fory.builder()
.withLanguage(Language.JAVA)
.withRefTracking(true)
.withCompatibleMode(CompatibleMode.SCHEMA_CONSISTENT)
.requireClassRegistration(false)
.buildThreadLocalFory()

override def serialize[T: ClassTag](t: T): ByteBuffer = {
val bytes = fury.serialize(t.asInstanceOf[AnyRef])
ByteBuffer.wrap(bytes)
}

override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
fury.deserialize(bytes).asInstanceOf[T]
}

override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
// Fory handles class loading internally, so we can use the standard deserialize method
deserialize[T](bytes)
}

override def serializeStream(s: OutputStream): SerializationStream = {
new ForySerializationStream(fury, s)
}

override def deserializeStream(s: InputStream): DeserializationStream = {
new ForyDeserializationStream(fury, s)
}
}

class ForySerializationStream(fury: ThreadLocalFory, outputStream: OutputStream)
extends org.apache.spark.serializer.SerializationStream {

private var closed = false

override def writeObject[T: ClassTag](t: T): SerializationStream = {
if (closed) {
throw new IllegalStateException("Stream is closed")
}
fury.serialize(outputStream, t)
this
}

override def flush(): Unit = {
if (!closed) {
outputStream.flush()
}
}

override def close(): Unit = {
if (!closed) {
try {
outputStream.close()
} finally {
closed = true
}
}
}
}

class ForyDeserializationStream(fury: ThreadLocalFory, inputStream: InputStream)
extends org.apache.spark.serializer.DeserializationStream {

private var closed = false
private val foryStream = new ForyInputStream(inputStream)

override def readObject[T: ClassTag](): T = {
if (closed) {
throw new IllegalStateException("Stream is closed")
}
fury.deserialize(foryStream).asInstanceOf[T]
}

override def close(): Unit = {
if (!closed) {
try {
foryStream.close()
} finally {
closed = true
}
}
}
}
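
For illustration, the stream API above round-trips like this sketch (the wrapper object and sample values are hypothetical):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.serializer.ForySerializerInstance

object ForyStreamExample {
  def main(args: Array[String]): Unit = {
    val instance = new ForySerializerInstance()
    val bytesOut = new ByteArrayOutputStream()
    // writeObject returns the stream, so calls can be chained.
    val out = instance.serializeStream(bytesOut)
    out.writeObject("hello").writeObject("uniffle")
    out.close()

    val in = instance.deserializeStream(new ByteArrayInputStream(bytesOut.toByteArray))
    assert(in.readObject[String]() == "hello")
    assert(in.readObject[String]() == "uniffle")
    in.close()
  }
}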