Skip to content

Commit 609eef9

Browse files
committed
use labelsArray instead of lables for StringIndexer
1 parent 65d672c commit 609eef9

File tree

6 files changed

+36
-24
lines changed

6 files changed

+36
-24
lines changed

mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,22 @@ import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}
77
*
88
* String indexer converts a string into an integer representation.
99
*
10-
* @param labels list of labels that can be indexed
10+
* @param labelsArray Array of ordered list of labels, corresponding to indices to be assigned for each input
1111
* @param handleInvalid how to handle invalid values (unseen or NULL labels): 'error' (throw an error),
1212
* 'skip' (skips invalid data)
1313
* or 'keep' (put invalid data in a special bucket at index labels.size
1414
*/
15-
case class StringIndexerModel(labels: Seq[Seq[String]],
16-
handleInvalid: HandleInvalid = HandleInvalid.Error) extends Model {
17-
private val stringToIndex: Array[Map[String, Int]] = labels.map(_.zipWithIndex.toMap).toArray
15+
case class StringIndexerModel(labelsArray: Array[Array[String]],
16+
handleInvalid: HandleInvalid) extends Model {
17+
18+
private val stringToIndex: Array[Map[String, Int]] = labelsArray.map(_.zipWithIndex.toMap)
1819
private val keepInvalid = handleInvalid == HandleInvalid.Keep
19-
private val invalidValue = labels.map(_.length)
20+
private val invalidValue = labelsArray.map(_.length)
21+
22+
23+
24+
@deprecated("Use labelsArray instead")
25+
def labels: Seq[String] = labelsArray(0).toSeq
2026

2127
/** Convert all strings into its integer representation.
2228
*
@@ -59,22 +65,28 @@ case class StringIndexerModel(labels: Seq[Seq[String]],
5965
}
6066

6167
/** Create a [[ml.combust.mleap.core.feature.ReverseStringIndexerModel]] from this model.
62-
*
68+
* ReverseStringIndexer only support one input
6369
* @return reverse string indexer of this string indexer
6470
*/
65-
// def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labels)
71+
def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labelsArray(0))
6672

6773
override def inputSchema: StructType = {
68-
val f = labels.zipWithIndex.map {
74+
val f = labelsArray.zipWithIndex.map {
6975
case (_, i) => StructField(s"input$i", ScalarType.String)
7076
}
7177
StructType(f).get
7278
}
7379

7480
override def outputSchema: StructType = {
75-
val f = labels.zipWithIndex.map {
81+
val f = labelsArray.zipWithIndex.map {
7682
case (_, i) => StructField(s"output$i", ScalarType.Double.nonNullable)
7783
}
7884
StructType(f).get
7985
}
8086
}
87+
88+
object StringIndexerModel {
89+
def apply(labels: Seq[String], handleInvalid: HandleInvalid): StringIndexerModel = StringIndexerModel(Array(labels.toArray), handleInvalid)
90+
def apply(labels: Seq[String]): StringIndexerModel = StringIndexerModel(Array(labels.toArray), HandleInvalid.Error)
91+
def apply(labelsArray: Array[Array[String]]): StringIndexerModel = StringIndexerModel(labelsArray, HandleInvalid.Error)
92+
}

mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@ import org.scalatest.prop.TableDrivenPropertyChecks
1010
class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with TableDrivenPropertyChecks {
1111
describe("#apply") {
1212
it("returns the index of the string") {
13-
val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude")))
13+
val indexer = StringIndexerModel(Seq("hello", "there", "dude"))
1414

1515
assert(indexer(Seq("hello")).head == 0.0)
1616
assert(indexer(Seq("there")).head == 1.0)
1717
assert(indexer(Seq("dude")).head == 2.0)
1818
}
1919

2020
it("throws NullPointerException when encounters NULL/None and handleInvalid is not keep") {
21-
val indexer = StringIndexerModel(Seq(Array("hello")))
21+
val indexer = StringIndexerModel(Array(Array("hello")))
2222
assertThrows[NullPointerException](indexer(null))
2323
}
2424

2525
it("throws NoSuchElementException when encounters unseen label and handleInvalid is not keep") {
26-
val indexer = StringIndexerModel(Seq(Array("hello")))
26+
val indexer = StringIndexerModel(Array(Array("hello")))
2727
val unseenLabels = Table("label", "unknown1", "unknown2")
2828

2929
forAll(unseenLabels) { (label: Any) =>
@@ -34,7 +34,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
3434
}
3535

3636
it("returns default index for HandleInvalid.keep mode") {
37-
val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude")), handleInvalid = HandleInvalid.Keep)
37+
val indexer = StringIndexerModel(Seq("hello", "there", "dude"), handleInvalid = HandleInvalid.Keep)
3838
val invalidLabels = Table("unknown", "other unknown", null, None)
3939

4040
forAll(invalidLabels) { (label: Any) =>
@@ -44,7 +44,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
4444
}
4545

4646
describe("input/output schema") {
47-
val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude")))
47+
val indexer = StringIndexerModel(Seq("hello", "there", "dude"))
4848

4949
it("has the right input schema") {
5050
assert(indexer.inputSchema.fields == Seq(StructField("input0", ScalarType.String)))

mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,22 @@ class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] {
2222
val m = model.
2323
withValue("labels_length", Value.int(1)).
2424
withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString))
25-
obj.labels.zipWithIndex.foldLeft(m){
26-
case (m, (label, i)) => m.withValue(s"labels_array_$i", Value.stringList(label))
25+
obj.labelsArray.zipWithIndex.foldLeft(m){
26+
case (m, (labels, i)) => m.withValue(s"labels_array_$i", Value.stringList(labels))
2727
}
2828
}
2929

3030
override def load(model: Model)
3131
(implicit context: BundleContext[MleapContext]): StringIndexerModel = {
3232
val handleInvalid = model.getValue("handle_invalid").map(_.getString).map(HandleInvalid.fromString(_)).getOrElse(HandleInvalid.default)
3333
val label_length = model.getValue("labels_length").map(_.getInt).getOrElse(-1)
34-
val labels: Seq[Seq[String]] = label_length match {
34+
val labelsArray: Array[Array[String]] = label_length match {
3535
case -1 =>
3636
// backawards compatibility with spark v2
37-
Seq(model.value("labels").getStringList)
38-
case _ => (0 until label_length).map(i=>model.value(s"labels_array_$i").getStringList)
37+
Array(model.value("labels").getStringList.toArray)
38+
case _ => (0 until label_length).map(i=>model.value(s"labels_array_$i").getStringList.toArray).toArray
3939
}
40-
StringIndexerModel(labels = labels, handleInvalid = handleInvalid)
40+
StringIndexerModel(labelsArray = labelsArray, handleInvalid = handleInvalid)
4141
}
4242
}
4343

mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class MleapSupportSpec extends org.scalatest.funspec.AnyFunSpec {
1616
private val stringIndexer = StringIndexer(shape = NodeShape().
1717
withStandardInput("feature").
1818
withStandardOutput("feature_index"),
19-
model = StringIndexerModel(Seq(Seq("label1", "label2"))))
19+
model = StringIndexerModel(Seq("label1", "label2")))
2020

2121
describe("URIBundleFileOps") {
2222
it("can save/load a bundle using a URI") {

mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package ml.combust.mleap.runtime.javadsl;
22

33
import ml.combust.mleap.core.feature.HandleInvalid$;
4-
import ml.combust.mleap.core.feature.StringIndexerModel;
4+
import ml.combust.mleap.core.feature.StringIndexerModel$;
55
import ml.combust.mleap.core.types.*;
66
import ml.combust.mleap.runtime.MleapContext;
77
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
@@ -52,7 +52,7 @@ private static Map<String, Double> createMap() {
5252
new NodeShape(new ListMap<>(), new ListMap<>()).
5353
withStandardInput("string").
5454
withStandardOutput("string_index"),
55-
new StringIndexerModel(JavaConversions.asScalaBuffer(Collections.singletonList(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq())).toSeq(),
55+
StringIndexerModel$.MODULE$.apply(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq(),
5656
HandleInvalid$.MODULE$.fromString("error", true)));
5757

5858
DefaultLeapFrame buildFrame() {

mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec {
1818
outputPort="output0",
1919
inputCol = "test_string",
2020
outputCol = "test_index"),
21-
model = StringIndexerModel(Seq(Seq("index1", "index2", "index3"))))
21+
model = StringIndexerModel(Seq("index1", "index2", "index3")))
2222

2323
describe("#transform") {
2424
it("converts input string into an index") {

0 commit comments

Comments
 (0)