Skip to content

Commit ac83d98

Browse files
Fixed an issue with OneHotEncoderOp when computing categorySizes
The Spark feature org.apache.spark.ml.feature.OneHotEncoderModel has two mixins for the input columns: inputCol and inputCols. We need to check which param is set and use the correct one to compute categorySizes.
1 parent 8784e8e commit ac83d98

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,16 @@ class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] {
4242
override def store(model: Model, obj: OneHotEncoderModel)
4343
(implicit context: BundleContext[SparkBundleContext]): Model = {
4444
assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))
45-
46-
val df = context.context.dataset.get
47-
val categorySizes = obj.getInputCols.map { f => OneHotEncoderOp.sizeForField(df.schema(f)) }
48-
45+
assert(!(obj.isSet(obj.inputCol) && obj.isSet(obj.inputCols)), "OneHotEncoderModel cannot have both inputCol and inputCols set")
46+
val inputCols = obj.isSet(obj.inputCol) match {
47+
case true => Array(obj.getInputCol)
48+
case false => obj.getInputCols
49+
}
50+
val df = context.context.dataset.get
51+
val categorySizes = inputCols.map { f => OneHotEncoderOp.sizeForField(df.schema(f)) }
4952
model.withValue("category_sizes", Value.intList(categorySizes))
5053
.withValue("drop_last", Value.boolean(obj.getDropLast))
5154
.withValue("handle_invalid", Value.string(obj.getHandleInvalid))
52-
5355
}
5456

5557
override def load(model: Model)

0 commit comments

Comments
 (0)