File tree Expand file tree Collapse file tree 2 files changed +26
-0
lines changed
mleap-spark-extension/src
main/scala/org/apache/spark/ml/mleap/feature
test/scala/org/apache/spark/ml/mleap/parity/feature Expand file tree Collapse file tree 2 files changed +26
-0
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,12 @@ class MathBinary(override val uid: String = Identifiable.randomUID("math_binary"
3333
3434 @ org.apache.spark.annotation.Since (" 2.0.0" )
3535 override def transform (dataset : Dataset [_]): DataFrame = {
36+ // Check this condition at runtime else the input schema inferred by MathBinaryModel might be wrong
37+ // and this will cause the transform to fail at inference time
38+ if ((isSet(inputA) && model.da.isDefined) || (isSet(inputB) && model.db.isDefined)) {
39+ throw new RuntimeException (" Only one of input column or default value can be present." )
40+ }
41+
3642 val binaryUdfA = udf {
3743 a : Double => model(Some (a), None )
3844 }
Original file line number Diff line number Diff line change @@ -22,5 +22,25 @@ class MathBinaryParitySpec extends SparkParityBase {
2222 setOutputCol(" bin_out" )
2323 )).fit(dataset)
2424
25+ describe(" has valid inputs" ) {
26+ val model = MathBinaryModel (Multiply , da= Some (4.0 ), db= Some (5.0 ))
27+ it(" Only one of inputA or defaultA" ) {
28+ val invalidSparkTransformer : Transformer = new MathBinary (uid = " math_bin" , model = model).
29+ setInputA(" dti" ).
30+ setOutputCol(" bin_out" )
31+ assertThrows[RuntimeException ] {
32+ invalidSparkTransformer.transform(dataset)
33+ }
34+ }
35+ it(" Only one of inputB or defaultB" ) {
36+ val invalidSparkTransformer : Transformer = new MathBinary (uid = " math_bin" , model = model).
37+ setInputB(" dti" ).
38+ setOutputCol(" bin_out" )
39+ assertThrows[RuntimeException ] {
40+ invalidSparkTransformer.transform(dataset)
41+ }
42+ }
43+ }
44+
2545 override val unserializedParams = Set (" stringOrderType" )
2646}
You can’t perform that action at this time.
0 commit comments