Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,38 @@ train.close();
test.close();
```

## Custom objectives

LightGBM4j supports using custom objective functions, but it doesn't provide any high-level wrappers as python API does.

LightGBM needs a tuple of 1st and 2nd order derivatives (gradients and hessians) computed for each datapoint. With LightGBM4j it looks like this for an MSE metric:

```java
LGBMDataset dataset = LGBMDataset.createFromFile("cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
// actual ground truth label values
float y[] = dataset.getFieldFloat("label");

for (int it=0; it<10; it++) {
// predictions for current iteration
double[] yhat = booster.getPredict(0); // 0 - training dataset
float[] grad = new float[y.length];
float[] hess = new float[y.length];
for (int i=0; i<y.length; i++) {
// 1-st derivative of squared error
grad[i] = (float)(2 * (yhat[i]-y[i]));
// 2-nd derivative of squared error
hess[i] = (float)(0 * (yhat[i]-y[i]) + 2);
}
booster.updateOneIterCustom(grad, hess);
// print the computed average error
double[] err = booster.getEval(0);
System.out.println("it " + it + " err=" + err[0]);
}
booster.close();
dataset.close();
```

## Supported platforms

This code is tested to work well with Linux (Ubuntu 20.04), Windows (Server 2019) and MacOS 10.15/11. Mac M1 is also supported.
Expand All @@ -157,12 +189,16 @@ Supported methods:
* [LGBM_BoosterFeatureImportance](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterFeatureImportance)
* [LGBM_BoosterGetEvalNames](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetEvalNames)
* [LGBM_BoosterGetNumFeature](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetNumFeature)
* [LGBM_BoosterGetNumClasses](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetNumClasses)
* [LGBM_BoosterGetNumPredict](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetNumPredict)
* [LGBM_BoosterGetPredict](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetPredict)
* [LGBM_BoosterLoadModelFromString](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterLoadModelFromString)
* [LGBM_BoosterPredictForMat](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMat)
* [LGBM_BoosterPredictForMatSingleRow](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow)
* [LGBM_BoosterSaveModel](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterSaveModel)
* [LGBM_BoosterSaveModelToString](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterSaveModelToString)
* [LGBM_BoosterUpdateOneIter](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterUpdateOneIter)
* [LGBM_BoosterUpdateOneIterCustom](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterUpdateOneIterCustom)
* [LGBM_DatasetCreateFromFile](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetCreateFromFile)
* [LGBM_DatasetCreateFromMat](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetCreateFromMat)
* [LGBM_DatasetFree](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetFree)
Expand All @@ -182,9 +218,6 @@ Not yet supported:
* [LGBM_BoosterGetEvalCounts](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetEvalCounts)
* [LGBM_BoosterGetLeafValue](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetLeafValue)
* [LGBM_BoosterGetLowerBoundValue](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetLowerBoundValue)
* [LGBM_BoosterGetNumClasses](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetNumClasses)
* [LGBM_BoosterGetNumPredict](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetNumPredict)
* [LGBM_BoosterGetPredict](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetPredict)
* [LGBM_BoosterGetUpperBoundValue](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetUpperBoundValue)
* [LGBM_BoosterMerge](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterMerge)
* [LGBM_BoosterNumberOfTotalModel](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterNumberOfTotalModel)
Expand All @@ -205,7 +238,6 @@ Not yet supported:
* [LGBM_BoosterRollbackOneIter](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterRollbackOneIter)
* [LGBM_BoosterSetLeafValue](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterSetLeafValue)
* [LGBM_BoosterShuffleModels](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterShuffleModels)
* [LGBM_BoosterUpdateOneIterCustom](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterUpdateOneIterCustom)
* [LGBM_DatasetAddFeaturesFrom](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetAddFeaturesFrom)
* [LGBM_DatasetCreateByReference](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetCreateByReference)
* [LGBM_DatasetCreateFromCSC](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetCreateFromCSC)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ protected SWIGTYPE_p_void() {
swigCPtr = 0;
}

protected static long getCPtr(SWIGTYPE_p_void obj) {
public static long getCPtr(SWIGTYPE_p_void obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
}
Expand Down
100 changes: 100 additions & 0 deletions src/main/java/io/github/metarank/lightgbm4j/LGBMBooster.java
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,105 @@ private int importanceType(FeatureImportanceType tpe) {
return importanceType;
}

/**
* Get number of classes.
* @return Number of classes
* @throws LGBMException
*/
public int getNumClasses() throws LGBMException {
SWIGTYPE_p_int numHandle = new_int32_tp();
int result = LGBM_BoosterGetNumClasses(voidpp_value(handle), numHandle);
if (result < 0) {
delete_intp(numHandle);
throw new LGBMException(LGBM_GetLastError());
} else {
int numClasses = intp_value(numHandle);
delete_intp(numHandle);
return numClasses;
}
}

/**
* Get number of predictions for training data and validation data (this can be used to support customized evaluation functions).
* @param dataIdx Index of data, 0: training data, 1: 1st validation data, 2: 2nd validation data and so on
* @return Number of predictions
* @throws LGBMException
*/
public long getNumPredict(int dataIdx) throws LGBMException {
SWIGTYPE_p_long_long numHandle = new_int64_tp();
int result = LGBM_BoosterGetNumPredict(voidpp_value(handle), dataIdx, numHandle);
if (result < 0) {
delete_int64_tp(numHandle);
throw new LGBMException(LGBM_GetLastError());
} else {
long numClasses = int64_tp_value(numHandle);
delete_int64_tp(numHandle);
return numClasses;
}
}

/**
* Get prediction for training data and validation data.
* @param dataIdx Index of data, 0: training data, 1: 1st validation data, 2: 2nd validation data and so on
* @return array with predictions, of size num_class * dataset.num_data
* @throws LGBMException
*/
public double[] getPredict(int dataIdx) throws LGBMException {
int allocatedSize = getNumClasses() * (int)getNumPredict(dataIdx);
SWIGTYPE_p_double buffer = new_doubleArray(allocatedSize);
SWIGTYPE_p_long_long size = new_int64_tp();
int result = LGBM_BoosterGetPredict(voidpp_value(handle), dataIdx, size, buffer);
if (result < 0) {
delete_doubleArray(buffer);
delete_int64_tp(size);
throw new LGBMException(LGBM_GetLastError());
} else {
double[] out = new double[(int)int64_tp_value(size)];
for (int i=0; i<out.length; i++) {
out[i] = doubleArray_getitem(buffer, i);
}
delete_doubleArray(buffer);
delete_int64_tp(size);
return out;
}
}

/**
* Update the model by specifying gradient and Hessian directly (this can be used to support customized loss functions).
* The length of the arrays referenced by grad and hess must be equal to num_class * num_train_data, this is not
* verified by the library, the caller must ensure this.
*
* @param grad The first order derivative (gradient) statistics
* @param hess The second order derivative (Hessian) statistics
* @return true means the update was successfully finished (cannot split anymore), false indicates failure
* @throws LGBMException
*/
public boolean updateOneIterCustom(float[] grad, float[] hess) throws LGBMException {
SWIGTYPE_p_float gradHandle = new_floatArray(grad.length);
for (int i=0; i<grad.length; i++) {
floatArray_setitem(gradHandle, i, grad[i]);
}
SWIGTYPE_p_float hessHandle = new_floatArray(hess.length);
for (int i=0; i<hess.length; i++) {
floatArray_setitem(hessHandle, i, hess[i]);
}
SWIGTYPE_p_int isFinishedHandle = new_intp();
int result = LGBM_BoosterUpdateOneIterCustom(voidpp_value(handle), gradHandle, hessHandle, isFinishedHandle);
if (result < 0) {
delete_floatArray(gradHandle);
delete_floatArray(hessHandle);
delete_intp(isFinishedHandle);
throw new LGBMException(LGBM_GetLastError());
} else {
int isFinished = intp_value(isFinishedHandle);
delete_floatArray(gradHandle);
delete_floatArray(hessHandle);
delete_intp(isFinishedHandle);
return isFinished == 1;
}
}


/**
* Calculates the output buffer size for the different prediction types. See the notes at:
* <a href="https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMat">predictForMat</a> &
Expand All @@ -592,4 +691,5 @@ else if (PredictionType.C_API_PREDICT_LEAF_INDEX.equals(predictionType))
else // for C_API_PREDICT_NORMAL & C_API_PREDICT_RAW_SCORE
return defaultSize;
}

}
78 changes: 78 additions & 0 deletions src/main/java/io/github/metarank/lightgbm4j/LGBMDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,83 @@ public String[] getFeatureNames() throws LGBMException {
return names;
}

/**
* Get float[] field from the dataset.
* @param field Field name
* @return
* @throws LGBMException
*/
public float[] getFieldFloat(String field) throws LGBMException {
SWIGTYPE_p_int lenPtr = new_intp();
SWIGTYPE_p_p_void bufferPtr = new_voidpp();
SWIGTYPE_p_int typePtr = new_intp();
int result = LGBM_DatasetGetField(handle, field, lenPtr, bufferPtr, typePtr);
if (result < 0) {
delete_intp(lenPtr);
delete_voidpp(bufferPtr);
delete_intp(typePtr);
throw new LGBMException(LGBM_GetLastError());
} else {
int len = intp_value(lenPtr);
int type = intp_value(typePtr);
if (type == C_API_DTYPE_FLOAT32) {
SWIGTYPE_p_void buf = voidpp_value(bufferPtr);
float[] out = new float[len];
for (int i=0; i<len; i++) {
// Hello, this is Johny Knoxville, and today we're reading a raw void pointer as an array of floats
out[i] = lightgbmlibJNI.floatArray_getitem(SWIGTYPE_p_void.getCPtr(buf), i);
}
delete_intp(lenPtr);
delete_voidpp(bufferPtr);
delete_intp(typePtr);
return out;
} else {
delete_intp(lenPtr);
delete_voidpp(bufferPtr);
delete_intp(typePtr);
throw new LGBMException("getFieldFloat expects a float field (of ctype=" + C_API_DTYPE_FLOAT32 + ") but got ctype="+type);
}
}
}
/**
* Get int[] field from the dataset.
* @param field Field name
* @return
* @throws LGBMException
*/
public int[] getFieldInt(String field) throws LGBMException {
// a copy-paste from getFieldFloat with different types, for the sake of performance
SWIGTYPE_p_int lenPtr = new_intp();
SWIGTYPE_p_p_void bufferPtr = new_voidpp();
SWIGTYPE_p_int typePtr = new_intp();
int result = LGBM_DatasetGetField(handle, field, lenPtr, bufferPtr, typePtr);
if (result < 0) {
delete_intp(lenPtr);
delete_voidpp(bufferPtr);
delete_intp(typePtr);
throw new LGBMException(LGBM_GetLastError());
} else {
int len = intp_value(lenPtr);
int type = intp_value(typePtr);
if (type == C_API_DTYPE_INT32) {
SWIGTYPE_p_void buf = voidpp_value(bufferPtr);
int[] out = new int[len];
for (int i=0; i<len; i++) {
out[i] = lightgbmlibJNI.intArray_getitem(SWIGTYPE_p_void.getCPtr(buf), i);
}
delete_intp(lenPtr);
delete_voidpp(bufferPtr);
delete_intp(typePtr);
return out;
} else {
delete_intp(lenPtr);
delete_voidpp(bufferPtr);
delete_intp(typePtr);
throw new LGBMException("getFieldFloat expects a float field (of ctype=" + C_API_DTYPE_FLOAT32 + ") but got ctype="+type);
}
}
}

/**
* Deallocate all native memory for the LightGBM dataset.
* @throws LGBMException
Expand All @@ -251,4 +328,5 @@ public void close() throws LGBMException {
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static Stream<Arguments> datasets() throws LGBMException, IOException {
Arguments.of(datasetReadmeExample())
);
}
private static LGBMDataset datasetFromFile() throws LGBMException, IOException {
public static LGBMDataset datasetFromFile() throws LGBMException, IOException {
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
return dataset;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.github.metarank.lightgbm4j;

import org.junit.jupiter.api.Test;

import java.io.IOException;

public class CustomObjectiveTest {

@Test
void testCancerCustomObjective() throws LGBMException, IOException {
LGBMDataset dataset = CancerIntegrationTest.datasetFromFile();
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
// actual ground truth label values
float y[] = dataset.getFieldFloat("label");
for (int it=0; it<10; it++) {
// predictions for current iteration
double[] yhat = booster.getPredict(0);
float[] grad = new float[y.length];
float[] hess = new float[y.length];
for (int i=0; i<y.length; i++) {
// 1-st derivative of squared error
grad[i] = (float)(2 * (yhat[i]-y[i]));
// 2-nd derivative of squared error
hess[i] = (float)(0 * (yhat[i]-y[i]) + 2);
}
booster.updateOneIterCustom(grad, hess);
// print the computed average error
double[] err = booster.getEval(0);
System.out.println("it " + it + " err=" + err[0]);
}
booster.close();
dataset.close();
}
}
51 changes: 51 additions & 0 deletions src/test/java/io/github/metarank/lightgbm4j/LGBMBoosterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Random;

import static org.junit.jupiter.api.Assertions.*;

public class LGBMBoosterTest {
Expand Down Expand Up @@ -293,12 +295,61 @@ void testCreateByReference() throws LGBMException {
assertEquals(train[0], test[0], 0.001);
}

@Test void testGetNumClasses() throws LGBMException {
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
assertEquals(booster.getNumClasses(), 1);
dataset.close();
booster.close();
}

@Test void testGetNumPredict() throws LGBMException {
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
assertEquals(booster.getNumPredict(0), 116);
dataset.close();
booster.close();
}

@Test void testGetPredict() throws LGBMException {
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
booster.updateOneIter();
booster.updateOneIter();
booster.updateOneIter();
double[] preds = booster.getPredict(0);
assertEquals(preds.length, 116);
dataset.close();
booster.close();
}

@Test void testUpdateOneIterCustom() throws LGBMException {
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
int size = dataset.getNumData();
booster.updateOneIterCustom(randomArray(size), randomArray(size));
booster.updateOneIterCustom(randomArray(size), randomArray(size));
booster.updateOneIterCustom(randomArray(size), randomArray(size));
double[] preds = booster.getPredict(0);
assertEquals(preds.length, 116);
dataset.close();
booster.close();
}

@Test void testDoubleClose() throws LGBMException {
LGBMDataset ds = LGBMDataset.createFromMat(new float[]{1.0f, 1.0f, 1.0f, 1.0f}, 2, 2, true, "", null);
LGBMBooster booster = LGBMBooster.create(ds, "");
booster.close();
booster.close();
}

private float[] randomArray(int size) {
float[] result = new float[size];
Random rnd = new Random();
for (int i=0; i<size; i++) {
result[i] = rnd.nextFloat();
}
return result;
}

}
Loading