@@ -42,32 +42,83 @@ class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] {
     override def store(model: Model, obj: OneHotEncoderModel)
                       (implicit context: BundleContext[SparkBundleContext]): Model = {
       assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))
 
+      assert(!(obj.isSet(obj.inputCol) && obj.isSet(obj.inputCols)), "OneHotEncoderModel cannot have both inputCol and inputCols set")
+      assert(!(obj.isSet(obj.outputCol) && obj.isSet(obj.outputCols)), "OneHotEncoderModel cannot have both outputCol and outputCols set")
+      val inputCols = if (obj.isSet(obj.inputCol)) Array(obj.getInputCol) else obj.getInputCols
       val df = context.context.dataset.get
-      val categorySizes = obj.getInputCols.map { f ⇒ OneHotEncoderOp.sizeForField(df.schema(f)) }
-
-      model.withValue("category_sizes", Value.intList(categorySizes))
+      val categorySizes = inputCols.map { f ⇒ OneHotEncoderOp.sizeForField(df.schema(f)) }
+      var m = model.withValue("category_sizes", Value.intList(categorySizes))
         .withValue("drop_last", Value.boolean(obj.getDropLast))
         .withValue("handle_invalid", Value.string(obj.getHandleInvalid))
+
+      if (obj.isSet(obj.inputCol)) {
+        m = m.withValue("inputCol", Value.string(obj.getInputCol))
+      }
+      if (obj.isSet(obj.inputCols)) {
+        m = m.withValue("inputCols", Value.stringList(obj.getInputCols))
+      }
+      if (obj.isSet(obj.outputCol)) {
+        m = m.withValue("outputCol", Value.string(obj.getOutputCol))
+      }
+      if (obj.isSet(obj.outputCols)) {
+        m = m.withValue("outputCols", Value.stringList(obj.getOutputCols))
+      }
+      m
     }
 
     override def load(model: Model)
                      (implicit context: BundleContext[SparkBundleContext]): OneHotEncoderModel = {
-      new OneHotEncoderModel(uid = "", categorySizes = model.value("category_sizes").getIntList.toArray)
-        .setDropLast(model.value("drop_last").getBoolean)
-        .setHandleInvalid(model.value("handle_invalid").getString)
+      val m = new OneHotEncoderModel(uid = "", categorySizes = model.value("category_sizes").getIntList.toArray)
+        .setDropLast(model.value("drop_last").getBoolean)
+        .setHandleInvalid(model.value("handle_invalid").getString)
+      if (model.getValue("inputCol").isDefined) {
+        m.setInputCol(model.value("inputCol").getString)
+      }
+      if (model.getValue("inputCols").isDefined) {
+        m.setInputCols(model.value("inputCols").getStringList.toArray)
+      }
+      if (model.getValue("outputCol").isDefined) {
+        m.setOutputCol(model.value("outputCol").getString)
+      }
+      if (model.getValue("outputCols").isDefined) {
+        m.setOutputCols(model.value("outputCols").getStringList.toArray)
+      }
+      m
     }
   }
 
   override def sparkLoad(uid: String, shape: NodeShape, model: OneHotEncoderModel): OneHotEncoderModel = {
-    new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes)
+    val m = new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes)
       .setDropLast(model.getDropLast)
       .setHandleInvalid(model.getHandleInvalid)
+    if (model.isSet(model.inputCol)) {
+      m.setInputCol(model.getInputCol)
+    }
+    if (model.isSet(model.inputCols)) {
+      m.setInputCols(model.getInputCols)
+    }
+    if (model.isSet(model.outputCol)) {
+      m.setOutputCol(model.getOutputCol)
+    }
+    if (model.isSet(model.outputCols)) {
+      m.setOutputCols(model.getOutputCols)
+    }
+    m
   }
 
-  override def sparkInputs(obj: OneHotEncoderModel): Seq[ParamSpec] = Seq(ParamSpec("input", obj.inputCols))
+  override def sparkInputs(obj: OneHotEncoderModel): Seq[ParamSpec] = {
+    obj.isSet(obj.inputCol) match {
+      case true => Seq(ParamSpec("input", obj.inputCol))
+      case false => Seq(ParamSpec("input", obj.inputCols))
+    }
+  }
 
-  override def sparkOutputs(obj: OneHotEncoderModel): Seq[ParamSpec] = Seq(ParamSpec("output", obj.outputCols))
+  override def sparkOutputs(obj: OneHotEncoderModel): Seq[ParamSpec] = {
+    obj.isSet(obj.outputCol) match {
+      case true => Seq(ParamSpec("output", obj.outputCol))
+      case false => Seq(ParamSpec("output", obj.outputCols))
+    }
+  }
 }
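
For context: with the store/load changes above, a Spark encoder configured through the single-column params (inputCol/outputCol) should now round-trip through an MLeap bundle. A minimal sketch of the save side, assuming MLeap's documented Spark export API; the column names, bundle path, and the pre-loaded `dataset` DataFrame (with a "state" string column) are illustrative, not part of this PR:

    // Sketch only: serialize a pipeline containing a single-column OneHotEncoder.
    import ml.combust.bundle.BundleFile
    import ml.combust.mleap.spark.SparkSupport._
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.bundle.SparkBundleContext
    import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
    import resource.managed

    val pipeline = new Pipeline().setStages(Array(
      new StringIndexer().setInputCol("state").setOutputCol("state_index"),
      new OneHotEncoder().setInputCol("state_index").setOutputCol("state_oh")
    )).fit(dataset)

    // store() above requires a transformed sample in the bundle context so it
    // can read category sizes from the output schema.
    implicit val sbc: SparkBundleContext =
      SparkBundleContext().withDataset(pipeline.transform(dataset))

    for (bundle <- managed(BundleFile("jar:file:/tmp/onehot-example.zip"))) {
      pipeline.writeBundle.save(bundle).get
    }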
@@ -1,5 +1,6 @@
 package org.apache.spark.ml.parity.feature
 
+import org.apache.spark.ml.bundle.SparkBundleContext
 import org.apache.spark.ml.parity.SparkParityBase
 import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
 import org.apache.spark.ml.{Pipeline, Transformer}
@@ -22,4 +23,40 @@ class OneHotEncoderParitySpec extends SparkParityBase {
     .fit(dataset)
 
   override val unserializedParams = Set("stringOrderType")
+
+  it("serializes/deserializes the Spark model properly with one in/out column") {
+    bundleCache = None
+    val additionalIgnoreParams = Set("outputCol")
+    val pipeline = new Pipeline()
+      .setStages(Array(
+        new StringIndexer().setInputCol("state").setOutputCol("state_index"),
+        new OneHotEncoder().setInputCol("state_index").setOutputCol("state_oh")
+      )).fit(dataset)
+    val sparkTransformed = pipeline.transform(dataset)
+    implicit val sbc = SparkBundleContext().withDataset(sparkTransformed)
+    val deserializedTransformer = deserializedSparkTransformer(pipeline)
+    checkEquality(pipeline, deserializedTransformer, additionalIgnoreParams)
+    equalityTest(sparkTransformed, deserializedTransformer.transform(dataset))
+    bundleCache = None
+  }
+
+  it("fails to instantiate if the Spark model sets inputCol and inputCols") {
+    intercept[IllegalArgumentException] {
+      new OneHotEncoder()
+        .setInputCol("state")
+        .setInputCols(Array("state_index", "state_index2"))
+        .setOutputCols(Array("state_oh", "state_oh2"))
+        .fit(dataset)
+    }
+  }
+
+  it("fails to instantiate if the Spark model sets outputCol and outputCols") {
+    intercept[IllegalArgumentException] {
+      new OneHotEncoder()
+        .setInputCol("state")
+        .setOutputCol("state_oh")
+        .setOutputCols(Array("state_oh", "state_oh2"))
+        .fit(dataset)
+    }
+  }
 }
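
And the corresponding read-back, again only a sketch under the same assumptions (path and names are made up; `loadSparkBundle` is MLeap's documented Spark loader):

    // Sketch only: deserialize the bundle saved above into a Spark transformer.
    import ml.combust.bundle.BundleFile
    import ml.combust.mleap.spark.SparkSupport._
    import resource.managed

    val restored = (for (bundle <- managed(BundleFile("jar:file:/tmp/onehot-example.zip"))) yield {
      bundle.loadSparkBundle().get
    }).opt.get.root
    // Because load() above only sets params whose values are present in the
    // bundle, a single-column encoder comes back with inputCol/outputCol set
    // and the multi-column params left unset, mirroring the original model.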