9 changes: 9 additions & 0 deletions .devcontainer/devcontainer.json
@@ -0,0 +1,9 @@
{
"image": "mcr.microsoft.com/devcontainers/universal:2",
"features": {
"ghcr.io/devcontainers-extra/features/scala-sdkman:2": {
"version": "2.12.18",
"jdkVersion": "11"
}
}
}
@@ -31,7 +31,7 @@ object FileUtil {
def rmRF(toRemove: Path): Array[(String, Boolean)] = {
def removeElement(path: Path): (String, Boolean) = {
val result = Try {
Files.deleteIfExists(toRemove)
Files.deleteIfExists(path)
} match {
case Failure(_) => false
case Success(value) => value
@@ -1,40 +1,62 @@
package ml.combust.mleap.core.feature

import ml.combust.mleap.core.Model
import ml.combust.mleap.core.types.{ScalarType, StructType}
import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}

/** Class for string indexer model.
*
* String indexer converts a string into an integer representation.
*
* @param labels list of labels that can be indexed
* @param labelsArray array of ordered label lists, one per input column; each label's position in its list is the index assigned for that input
* @param handleInvalid how to handle invalid values (unseen or NULL labels): 'error' (throw an error),
* 'skip' (skip rows with invalid data),
* or 'keep' (put invalid data in a special bucket at index labels.size of the corresponding column)
*/
case class StringIndexerModel(labels: Seq[String],
handleInvalid: HandleInvalid = HandleInvalid.Error) extends Model {
val stringToIndex: Map[String, Int] = labels.zipWithIndex.toMap
case class StringIndexerModel(labelsArray: Array[Array[String]],
handleInvalid: HandleInvalid) extends Model {

private val stringToIndex: Array[Map[String, Int]] = labelsArray.map(_.zipWithIndex.toMap)
private val keepInvalid = handleInvalid == HandleInvalid.Keep
private val invalidValue = labels.length
private val invalidValue = labelsArray.map(_.length)



@deprecated("Use labelsArray instead")
def labels: Seq[String] = labelsArray(0).toSeq

/** Convert all strings into its integer representation.
*
* @param values labels to index
* @return indexes of labels
*/
def apply(values: Seq[Any]): Seq[Double] = values.zipWithIndex.map {
case (v: Any, i: Int) => encoder(v, i).toDouble
case (null, i: Int) => encoder(null, i).toDouble
}

def contains(values: Seq[Any]): Boolean = {
values.zipWithIndex.forall {
case (key, i) => stringToIndex(i).contains(key.toString)
}
}
/** Convert a string into its integer representation.
*
* @param value label to index
* @return index of label
*/
def apply(value: Any): Int = if (value == null) {
private def encoder(value: Any, colIdx: Int): Int = if (value == null) {
if (keepInvalid) {
invalidValue
invalidValue(colIdx)
} else {
throw new NullPointerException("StringIndexer encountered NULL value. " +
s"To handle NULLS, set handleInvalid to ${HandleInvalid.Keep.asParamString}")
}
} else {
val label = value.toString
stringToIndex.get(label) match {
stringToIndex(colIdx).get(label) match {
case Some(v) => v
case None => if (keepInvalid) {
invalidValue
invalidValue(colIdx)
} else {
throw new NoSuchElementException(s"Unseen label: $label. To handle unseen labels, " +
s"set handleInvalid to ${HandleInvalid.Keep.asParamString}")
@@ -43,12 +65,28 @@ case class StringIndexerModel(labels: Seq[String],
}

/** Create a [[ml.combust.mleap.core.feature.ReverseStringIndexerModel]] from this model.
*
* ReverseStringIndexer only supports one input column
* @return reverse string indexer of this string indexer
*/
def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labels)
def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labelsArray(0))

override def inputSchema: StructType = {
val f = labelsArray.zipWithIndex.map {
case (_, i) => StructField(s"input$i", ScalarType.String)
}
StructType(f).get
}

override def inputSchema: StructType = StructType("input" -> ScalarType.String).get
override def outputSchema: StructType = {
val f = labelsArray.zipWithIndex.map {
case (_, i) => StructField(s"output$i", ScalarType.Double.nonNullable)
}
StructType(f).get
}
}

override def outputSchema: StructType = StructType("output" -> ScalarType.Double.nonNullable).get
object StringIndexerModel {
def apply(labels: Seq[String], handleInvalid: HandleInvalid): StringIndexerModel = StringIndexerModel(Array(labels.toArray), handleInvalid)
def apply(labels: Seq[String]): StringIndexerModel = StringIndexerModel(Array(labels.toArray), HandleInvalid.Error)
def apply(labelsArray: Array[Array[String]]): StringIndexerModel = StringIndexerModel(labelsArray, HandleInvalid.Error)
}
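
For context, a minimal usage sketch of the new multi-column API (not part of the diff); it relies only on the constructor and the apply signature introduced above, and the label values are invented for illustration.

import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel}

object StringIndexerModelSketch extends App {
  // One label array per input column: column 0 knows "a"/"b", column 1 knows only "x".
  val model = StringIndexerModel(
    labelsArray = Array(Array("a", "b"), Array("x")),
    handleInvalid = HandleInvalid.Keep)

  // apply takes one value per column and returns one index per column.
  println(model(Seq("b", "x")))      // -> 1.0, 0.0 (indices of "b" and "x")

  // With HandleInvalid.Keep, an unseen label falls into the extra bucket
  // at index labelsArray(i).length for its column.
  println(model(Seq("unseen", "x"))) // -> 2.0, 0.0
}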
@@ -12,9 +12,9 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
it("returns the index of the string") {
val indexer = StringIndexerModel(Array("hello", "there", "dude"))

assert(indexer("hello") == 0.0)
assert(indexer("there") == 1.0)
assert(indexer("dude") == 2.0)
assert(indexer(Seq("hello")).head == 0.0)
assert(indexer(Seq("there")).head == 1.0)
assert(indexer(Seq("dude")).head == 2.0)
}

it("throws NullPointerException when encounters NULL/None and handleInvalid is not keep") {
@@ -24,11 +24,11 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table

it("throws NoSuchElementException when encounters unseen label and handleInvalid is not keep") {
val indexer = StringIndexerModel(Array("hello"))
val unseenLabels = Table("unknown1", "unknown2")
val unseenLabels = Table("label", "unknown1", "unknown2")

forAll(unseenLabels) { (label: Any) =>
intercept[NoSuchElementException] {
indexer(label)
indexer(Seq(label))
}
}
}
@@ -38,7 +38,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
val invalidLabels = Table("unknown", "other unknown", null, None)

forAll(invalidLabels) { (label: Any) =>
assert(indexer(label) == 3.0)
assert(indexer(Seq(label)).head == 3.0)
}
}
}
@@ -47,11 +47,11 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
val indexer = StringIndexerModel(Array("hello", "there", "dude"))

it("has the right input schema") {
assert(indexer.inputSchema.fields == Seq(StructField("input", ScalarType.String)))
assert(indexer.inputSchema.fields == Seq(StructField("input0", ScalarType.String)))
}

it("has the right output schema") {
assert(indexer.outputSchema.fields == Seq(StructField("output", ScalarType.Double.nonNullable)))
assert(indexer.outputSchema.fields == Seq(StructField("output0", ScalarType.Double.nonNullable)))
}
}
}
@@ -19,27 +19,28 @@ class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] {

override def store(model: Model, obj: StringIndexerModel)
(implicit context: BundleContext[MleapContext]): Model = {
model.
val m = model.
withValue("labels_length", Value.int(1)).
withValue("labels_array_0", Value.stringList(obj.labels)).
withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString))

obj.labelsArray.zipWithIndex.foldLeft(m){
case (m, (labels, i)) => m.withValue(s"labels_array_$i", Value.stringList(labels))
}
}

override def load(model: Model)
(implicit context: BundleContext[MleapContext]): StringIndexerModel = {
val handleInvalid = model.getValue("handle_invalid").map(_.getString).map(HandleInvalid.fromString(_)).getOrElse(HandleInvalid.default)
val label_length = model.getValue("labels_length").map(_.getInt).getOrElse(-1)
val labels: Seq[String] = label_length match {
val labelsArray: Array[Array[String]] = label_length match {
case -1 =>
// backwards compatibility with spark v2
model.value("labels").getStringList
case 1 => model.value("labels_array_0").getStringList
case _ => throw new UnsupportedOperationException("Multi-input StringIndexer not supported yet.")
Array(model.value("labels").getStringList.toArray)
case _ => (0 until label_length).map(i => model.value(s"labels_array_$i").getStringList.toArray).toArray
}
StringIndexerModel(labels = labels, handleInvalid = handleInvalid)
StringIndexerModel(labelsArray = labelsArray, handleInvalid = handleInvalid)
}
}

override def model(node: StringIndexer): StringIndexerModel = node.model

}
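
A small standalone sketch (not part of the change) of the foldLeft pattern store uses above to write one labels_array_<i> attribute per column; the keys mirror those read back by load, but a plain Map stands in for the bundle Model so the snippet runs on its own, and the label values are invented.

object LabelsArrayLayoutSketch extends App {
  val labelsArray: Array[Array[String]] = Array(Array("a", "b"), Array("x"))

  // A plain Map stands in for the bundle Model; attribute values are shown as strings.
  val base = Map("handle_invalid" -> "keep", "labels_length" -> labelsArray.length.toString)

  // One labels_array_<i> entry per input column, mirroring the foldLeft in store.
  val attrs = labelsArray.zipWithIndex.foldLeft(base) {
    case (m, (labels, i)) => m + (s"labels_array_$i" -> labels.mkString(","))
  }

  println(attrs)
  // keys: handle_invalid, labels_length, labels_array_0, labels_array_1
}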
@@ -0,0 +1,33 @@
package ml.combust.mleap.runtime.transformer.feature

import ml.combust.mleap.core.types._
import ml.combust.mleap.runtime.frame.Transformer

/*
This Transformer trait is only used when the Spark Transformer supports both inputCol/outputCol
and inputCols/outputCols. When the single-column parameters were used, the saved shape exposes
port "input" instead of "input0" and "output" instead of "output0", so the model's schema field
names must be mapped back to the shape's port names.
*/
trait MultiInOutTransformer extends Transformer {
override def inputSchema: StructType = {
if (shape.getInput("input").isDefined) {
val fields = model.inputSchema.getField("input0").map {
case StructField(_, dataType) => StructField(shape.input("input").name, dataType)
}.toSeq
StructType(fields).get
} else {
super.inputSchema
}
}

override def outputSchema: StructType = {
if (shape.getOutput("output").isDefined) {
val fields = model.outputSchema.getField("output0").map {
case StructField(_, dataType) => StructField(shape.output("output").name, dataType)
}.toSeq
StructType(fields).get
} else {
super.outputSchema
}
}
}
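
To illustrate the remapping this trait performs, a simplified self-contained sketch (not part of the diff); it reuses only StructType/StructField/ScalarType as used above, and the port name is hypothetical.

import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}

object PortRenameSketch extends App {
  // What the model reports under the multi-column naming convention.
  val modelInputSchema = StructType(Seq(StructField("input0", ScalarType.String))).get

  // Hypothetical name the saved single-column shape uses for its "input" port.
  val shapeInputName = "my_string_column"

  // Keep the data type from the model's "input0" field but expose it under the
  // shape's port name, mirroring what inputSchema in the trait above does.
  val renamed = modelInputSchema.getField("input0").map {
    case StructField(_, dataType) => StructField(shapeInputName, dataType)
  }.toSeq

  println(StructType(renamed).get)
}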
@@ -2,29 +2,39 @@ package ml.combust.mleap.runtime.transformer.feature

import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel}
import ml.combust.mleap.core.types._
import ml.combust.mleap.runtime.function.{FieldSelector, UserDefinedFunction}
import ml.combust.mleap.runtime.frame.{FrameBuilder, Transformer}
import ml.combust.mleap.runtime.function.{StructSelector, UserDefinedFunction}
import ml.combust.mleap.runtime.frame.{FrameBuilder, Row, Transformer}

import scala.util.Try

/**
* Created by hwilkins on 10/22/15.
*/
case class StringIndexer(override val uid: String = Transformer.uniqueName("string_indexer"),
override val shape: NodeShape,
override val model: StringIndexerModel) extends Transformer {
val input: String = inputSchema.fields.head.name
val inputSelector: FieldSelector = input
val output: String = outputSchema.fields.head.name
val exec: UserDefinedFunction = (value: String) => model(value).toDouble
override val model: StringIndexerModel) extends Transformer with MultiInOutTransformer {
private val outputs: Seq[String] = outputSchema.fields.map(_.name)
private val inputs: Seq[String] = inputSchema.fields.map(_.name)
private val inputSelector: StructSelector = StructSelector(inputs)
private val filterSchema = StructType(Seq(StructField("output", ScalarType.Boolean.nonNullable))).get
private val exec: UserDefinedFunction = UserDefinedFunction((keys: Row) => {
val res = model(keys.toSeq)
Row(res:_*)
}, SchemaSpec(outputSchema), Seq(SchemaSpec(inputSchema)))
private val contains: UserDefinedFunction = UserDefinedFunction((keys: Row) => {
model.contains(keys.toSeq)
}, SchemaSpec(filterSchema), Seq(SchemaSpec(inputSchema)))

override def transform[FB <: FrameBuilder[FB]](builder: FB): Try[FB] = {
def withColumns(builder: FB): Try[FB] = {
builder.withColumns(outputs, inputSelector)(exec)
}

if(model.handleInvalid == HandleInvalid.Skip) {
builder.filter(input) {
(key: String) => model.stringToIndex.contains(key)
}.flatMap(_.withColumn(output, inputSelector)(exec))
builder.filter(inputSelector)(contains)
.flatMap(withColumns)
} else {
builder.withColumn(output, inputSelector)(exec)
withColumns(builder)
}
}
}
}
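
A short sketch (not part of the diff) of what the HandleInvalid.Skip branch above filters on: rows are kept only when contains is true for all of their input values. Labels are invented for illustration.

import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel}

object SkipFilterSketch extends App {
  val model = StringIndexerModel(Array(Array("a", "b"), Array("x")), HandleInvalid.Skip)

  // The contains UDF in the transformer wraps exactly this per-row check.
  println(model.contains(Seq("a", "x")))      // true  -> row is kept and indexed by exec
  println(model.contains(Seq("unseen", "x"))) // false -> row is filtered out under Skip
}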
@@ -1,7 +1,7 @@
package ml.combust.mleap.runtime.javadsl;

import ml.combust.mleap.core.feature.HandleInvalid$;
import ml.combust.mleap.core.feature.StringIndexerModel;
import ml.combust.mleap.core.feature.StringIndexerModel$;
import ml.combust.mleap.core.types.*;
import ml.combust.mleap.runtime.MleapContext;
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
@@ -52,7 +52,7 @@ private static Map<String, Double> createMap() {
new NodeShape(new ListMap<>(), new ListMap<>()).
withStandardInput("string").
withStandardOutput("string_index"),
new StringIndexerModel(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq(),
StringIndexerModel$.MODULE$.apply(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq(),
HandleInvalid$.MODULE$.fromString("error", true)));

DefaultLeapFrame buildFrame() {