Description
Problem description:
We were able to serialize the XGBoost model with MLeap using the older PySpark API (dmlc/xgboost#4656) as shown below:
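# Importing mleap.pyspark adds serializeToBundle to Spark ML transformers
import mleap.pyspark
from mleap.pyspark.spark_support import SimpleSparkSerializer
# MLeap needs a transformed DataFrame to record the output schema
trans_model = model.transform(df)
# The bundle is written as a zip archive at this jar:file: URI
local_path = "jar:file:/tmp/pyspark.model.zip"
model.serializeToBundle(local_path, trans_model)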
But we are not able to do the same with the official PySpark API (https://xgboost.readthedocs.io/en/stable/tutorials/spark_estimator.html). Are there any flags that I need to set to make this work?
We use MLeap to store the model as a serialized bundle and load it in a Java runtime environment for model serving.
Steps Followed:
Download spark
wget https://dlcdn.apache.org/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz
tar -xvf spark-3.3.3-bin-hadoop3.tgz
Download xgboost jars
wget https://repo1.maven.org/maven2/ml/dmlc/xgboost4j_2.12/1.7.3/xgboost4j_2.12-1.7.3.jar
wget https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-spark_2.12/1.7.3/xgboost4j-spark_2.12-1.7.3.jar
Build MLeap fat jar:
https://github.com/combust/mleap/blob/master/mleap-databricks-runtime-fat/README.md
git clone --recursive https://github.com/combust/mleap.git
cd mleap
git checkout tags/v0.22.0
sbt mleap-databricks-runtime-fat/assembly
cp mleap-databricks-runtime-fat/target/scala-2.12/mleap-databricks-runtime-fat-assembly-0.22.0.jar ../spark-3.3.3-bin-hadoop3/jars
Install python requirements
pip install mleap==0.22.0
pip install xgboost==1.7.3
pip install pyarrow
Running the example
cd ../spark-3.3.3-bin-hadoop3
./bin/spark-submit example.py
error log
23/09/13 22:06:29 INFO CodeGenerator: Code generated in 25.565617 ms
+---------+----+-------+--------+----------+------+--------+-----+------------+----+----+-----+--------------------+-------------+----------+-----------+
|feat1|feat2|feat3|feat4|feat5|feat7|feat8| feat9|feat10|feat11|feat12|label| features|rawPrediction|prediction|probability|
+---------+----+-------+--------+----------+------+--------+-----+------------+----+----+-----+--------------------+-------------+----------+-----------+
| 7| 20| 3| 6| 1| 10| 3|53948| 245351| 1| 2| 1|[7.0,20.0,3.0,6.0...| [-0.0,0.0]| 0.0| [0.5,0.5]|
| 7| 20| 3| 6| 1| 10| 3|53948| 245351| 1| 2| 1|[7.0,20.0,3.0,6.0...| [-0.0,0.0]| 0.0| [0.5,0.5]|
| 7| 20| 1| 6| 1| 10| 3|53948| 245351| 1| 2| 0|[7.0,20.0,1.0,6.0...| [-0.0,0.0]| 0.0| [0.5,0.5]|
| 7| 20| 1| 6| 1| 10| 3|53948| 245351| 1| 2| 0|[7.0,20.0,1.0,6.0...| [-0.0,0.0]| 0.0| [0.5,0.5]|
| 5| 20| 1| 6| 1| 10| 3|53948| 245351| 1| 2| 0|[5.0,20.0,1.0,6.0...| [-0.0,0.0]| 0.0| [0.5,0.5]|
| 5| 20| 3| 6| 1| 10| 3|53948| 245351| 1| 2| 1|[5.0,20.0,3.0,6.0...| [-0.0,0.0]| 0.0| [0.5,0.5]|
+---------+----+-------+--------+----------+------+--------+-----+------------+----+----+-----+--------------------+-------------+----------+-----------+
Traceback (most recent call last):
File "/Users/s0a018g/local_setup/spark-3.3.3-bin-hadoop3/bin/code1.py", line 47, in <module>
model.serializeToBundle(local_path, predictions)
File "/Users/s0a018g/opt/anaconda3/lib/python3.8/site-packages/mleap/pyspark/spark_support.py", line 25, in serializeToBundle
serializer.serializeToBundle(self, path, dataset=dataset)
File "/Users/s0a018g/opt/anaconda3/lib/python3.8/site-packages/mleap/pyspark/spark_support.py", line 42, in serializeToBundle
self._java_obj.serializeToBundle(transformer._to_java(), path, dataset._jdf)
File "/Users/s0a018g/local_setup/spark-3.3.3-bin-hadoop3/python/lib/pyspark.zip/pyspark/ml/pipeline.py", line 363, in _to_java
AttributeError: 'SparkXGBClassifierModel' object has no attribute '_to_java'
23/09/13 22:06:30 INFO SparkContext: Invoking stop() from shutdown hook
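Reading the traceback, the failure seems to come from `PipelineModel._to_java` in `pyspark/ml/pipeline.py`, which calls `_to_java()` on each stage, while `SparkXGBClassifierModel` does not define that method (presumably because the PySpark-native XGBoost model is implemented in Python rather than wrapping a JVM object). Below is a minimal sketch of a pre-flight check that reports which stages cannot be handed to a JVM-side serializer. The classes `JvmBackedModel` and `PurePythonModel` and the helper `find_stages_without_java` are hypothetical stand-ins for illustration only; they are not part of PySpark, XGBoost, or MLeap.

```python
# Hypothetical stand-ins, used only to illustrate the mechanism.
class JvmBackedModel:
    """Mimics a PySpark model that wraps a JVM object and exposes _to_java."""

    def _to_java(self):
        return "jvm-object-handle"


class PurePythonModel:
    """Mimics a model implemented only in Python, with no _to_java method."""


def find_stages_without_java(stages):
    """Return the class names of stages that lack a callable _to_java."""
    return [
        type(stage).__name__
        for stage in stages
        if not callable(getattr(stage, "_to_java", None))
    ]


if __name__ == "__main__":
    stages = [JvmBackedModel(), PurePythonModel()]
    missing = find_stages_without_java(stages)
    if missing:
        # Fail early with a readable message instead of an AttributeError
        # raised deep inside the serializer.
        print("Stages without _to_java:", ", ".join(missing))
```

With a real fitted pipeline, the same check could be run as `find_stages_without_java(model.stages)` before calling `serializeToBundle`. It only diagnoses the problem; it does not make the model serializable.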
Thanks @agsachin for creating the instructions that can be easily reproduced on Mac.