Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DecisionTreeClassifierOp extends MleapOp[DecisionTreeClassifier, DecisionT

override def store(model: Model, obj: DecisionTreeClassifierModel)
(implicit context: BundleContext[MleapContext]): Model = {
TreeSerializer[tree.Node](context.file("nodes"), withImpurities = true).write(obj.rootNode)
TreeSerializer[tree.Node](context.file("tree"), withImpurities = true).write(obj.rootNode)
model.withValue("num_features", Value.long(obj.numFeatures)).
withValue("num_classes", Value.long(obj.numClasses)).
withValue("thresholds", obj.thresholds.map(Value.doubleList(_)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DecisionTreeRegressionOp extends MleapOp[DecisionTreeRegression, DecisionT

override def store(model: Model, obj: DecisionTreeRegressionModel)
(implicit context: BundleContext[MleapContext]): Model = {
TreeSerializer[tree.Node](context.file("nodes"), withImpurities = false).write(obj.rootNode)
TreeSerializer[tree.Node](context.file("tree"), withImpurities = false).write(obj.rootNode)
model.withValue("num_features", Value.long(obj.numFeatures))
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
package ml.combust.mleap.runtime.transformer.classification

import ml.combust.bundle.BundleFile
import ml.combust.bundle.serializer.SerializationFormat
import ml.combust.mleap.core.classification.DecisionTreeClassifierModel
import ml.combust.mleap.core.tree.{ContinuousSplit, InternalNode, LeafNode}
import ml.combust.mleap.core.types._
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.mleap.runtime.test.TestUtil
import org.scalatest.FunSpec
import resource.managed

import java.io.File
import java.net.URI

class DecisionTreeClassifierSpec extends FunSpec {

Expand Down Expand Up @@ -45,4 +54,32 @@ class DecisionTreeClassifierSpec extends FunSpec {
StructField("prediction", ScalarType.Double.nonNullable)))
}
}

describe("save/load model") {
it("correctly reproduces the model when saved and loaded") {
val node = InternalNode(LeafNode(Seq(0.78)), LeafNode(Seq(0.34)), ContinuousSplit(0, 0.5))
val transformer = DecisionTreeClassifier(shape = NodeShape.probabilisticClassifier(rawPredictionCol = Some("rp"),
probabilityCol = Some("probability")),
model = new DecisionTreeClassifierModel(node, 3, 2))

// serialization
val fileName = s"${TestUtil.baseDir}/decision_tree_classifier_saved_model.json.zip"
val uri = new URI(s"jar:file:$fileName")
for (file <- managed(BundleFile(uri))) {
transformer.writeBundle.name("bundle")
.format(SerializationFormat.Json)
.save(file)
}

// de-serialization
val file = new File(fileName)
val loadedTransformer = (for (bf <- managed(BundleFile(file))) yield {
bf.loadMleapBundle().get.root
}).tried.get.asInstanceOf[DecisionTreeClassifier]

// checks
assert(transformer.inputSchema.fields equals (loadedTransformer.inputSchema.fields))
assert(transformer.outputSchema.fields equals (loadedTransformer.outputSchema.fields))
}
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package ml.combust.mleap.runtime.transformer.regression

import ml.combust.bundle.BundleFile
import ml.combust.bundle.serializer.SerializationFormat
import ml.combust.mleap.core.regression.DecisionTreeRegressionModel
import ml.combust.mleap.core.tree.{ContinuousSplit, InternalNode, LeafNode}
import ml.combust.mleap.core.types._
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.mleap.runtime.test.TestUtil
import org.scalatest.FunSpec
import resource.managed

import java.io.File
import java.net.URI

class DecisionTreeRegressionSpec extends FunSpec {

Expand All @@ -20,4 +28,33 @@ class DecisionTreeRegressionSpec extends FunSpec {
StructField("prediction", ScalarType.Double.nonNullable)))
}
}

describe("save/load model") {
it("correctly reproduces the model when saved and loaded") {
val node = InternalNode(LeafNode(Seq(0.78)), LeafNode(Seq(0.34)), ContinuousSplit(0, 0.5))
val regression = DecisionTreeRegressionModel(node, 3)

val transformer = DecisionTreeRegression(shape = NodeShape.regression(),
model = regression)

// serialization
val fileName = s"${TestUtil.baseDir}/decision_tree_regression_saved_model.json.zip"
val uri = new URI(s"jar:file:$fileName")
for (file <- managed(BundleFile(uri))) {
transformer.writeBundle.name("bundle")
.format(SerializationFormat.Json)
.save(file)
}

// de-serialization
val file = new File(fileName)
val loadedTransformer = (for (bf <- managed(BundleFile(file))) yield {
bf.loadMleapBundle().get.root
}).tried.get.asInstanceOf[DecisionTreeRegression]

// checks
assert(transformer.inputSchema.fields equals (loadedTransformer.inputSchema.fields))
assert(transformer.outputSchema.fields equals (loadedTransformer.outputSchema.fields))
}
}
}