Skip to content

Commit 7fc9694

Browse files
authored
Merge pull request #841 from praj-0/u/praj/math_binary_input_validation
Allow only one of input or default in MathBinary
2 parents fb43440 + e993d3e commit 7fc9694

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

mleap-spark-extension/src/main/scala/org/apache/spark/ml/mleap/feature/MathBinary.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ class MathBinary(override val uid: String = Identifiable.randomUID("math_binary"
3333

3434
@org.apache.spark.annotation.Since("2.0.0")
3535
override def transform(dataset: Dataset[_]): DataFrame = {
36+
// Check this condition at runtime else the input schema inferred by MathBinaryModel might be wrong
37+
// and this will cause the transform to fail at inference time
38+
if((isSet(inputA) && model.da.isDefined) || (isSet(inputB) && model.db.isDefined)) {
39+
throw new RuntimeException("Only one of input column or default value can be present.")
40+
}
41+
3642
val binaryUdfA = udf {
3743
a: Double => model(Some(a), None)
3844
}

mleap-spark-extension/src/test/scala/org/apache/spark/ml/mleap/parity/feature/MathBinaryParitySpec.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,25 @@ class MathBinaryParitySpec extends SparkParityBase {
2222
setOutputCol("bin_out")
2323
)).fit(dataset)
2424

25+
describe("has valid inputs") {
26+
val model = MathBinaryModel(Multiply, da=Some(4.0), db=Some(5.0))
27+
it("Only one of inputA or defaultA") {
28+
val invalidSparkTransformer: Transformer = new MathBinary(uid = "math_bin", model = model).
29+
setInputA("dti").
30+
setOutputCol("bin_out")
31+
assertThrows[RuntimeException] {
32+
invalidSparkTransformer.transform(dataset)
33+
}
34+
}
35+
it("Only one of inputB or defaultB") {
36+
val invalidSparkTransformer: Transformer = new MathBinary(uid = "math_bin", model = model).
37+
setInputB("dti").
38+
setOutputCol("bin_out")
39+
assertThrows[RuntimeException] {
40+
invalidSparkTransformer.transform(dataset)
41+
}
42+
}
43+
}
44+
2545
override val unserializedParams = Set("stringOrderType")
2646
}

0 commit comments

Comments
 (0)