Skip to content

Commit 75ac147

Browse files
authored
position bias removal support (#77)
1 parent dbdd9a4 commit 75ac147

File tree

3 files changed

+46
-7
lines changed

3 files changed

+46
-7
lines changed

README.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ So if you do `parameters = "label=some_column_name"`, it will be ignored by the
106106
`createFromMat`
107107
* to set these magical columns, you need to explicitly call `LGBMDataset.setField()` method.
108108
* `label` and `weight` columns [must be](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetSetField) `float[]`
109-
* `group` column [must be](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetSetField) `int[]`
109+
* `group` and `position` columns [must be](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_DatasetSetField) `int[]`
110110

111111
A full example of loading dataset from a matrix for a cancer dataset:
112112
```java
@@ -200,6 +200,25 @@ for (int i=0; i<10; i++) {
200200
}
201201
```
202202

203+
### Position bias removal
204+
205+
LightGBM 4.1+ supports [position-bias aware LTR/LambdaMART](https://lightgbm.readthedocs.io/en/latest/Advanced-Topics.html#support-for-position-bias-treatment) training. To use it with lightgbm4j, you need to explicitly set the `position` field, as described in the upstream LightGBM docs:
206+
207+
```java
208+
float[] matrix = new float[] {
209+
// query group 1
210+
1.0f, 2.0f, // doc1
211+
3.0f, 4.0f, // doc2
212+
// query group 2
213+
1.0f, 2.0f, // doc1
214+
3.0f, 4.0f}; // doc2
215+
LGBMDataset ds = LGBMDataset.createFromMat(matrix, 4, 2, true, "", null);
216+
ds.setField("label", new float[] {1.0f, 0.0f, 1.0f, 0.0f}); // set relevance labels
217+
ds.setField("group", new int[] {2, 2}); // set query group sizes: 2 docs in each of the 2 queries
218+
ds.setField("position", new int[] {0, 1, 0, 1}); // bias classes: one position per document (row)
219+
LGBMBooster booster = LGBMBooster.create(ds, "objective=lambdarank");
220+
```
221+
203222
### Custom objectives
204223

205224
LightGBM4j supports using custom objective functions, but it doesn't provide any high-level wrappers as python API does.

src/main/java/io/github/metarank/lightgbm4j/LGBMDataset.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ public void setField(String fieldName, double[] data) throws LGBMException {
198198
* @throws LGBMException
199199
*/
200200
public void setField(String fieldName, int[] data) throws LGBMException {
201-
if (!fieldName.equals("group")) throw new LGBMException("only group field can be int[]");
201+
if (fieldName.equals("label")) throw new LGBMException("label can only be float[]");
202+
if (fieldName.equals("weight")) throw new LGBMException("weight can only be float[]");
202203
SWIGTYPE_p_int dataBuffer = new_intArray(data.length);
203204
for (int i = 0; i < data.length; i++) {
204205
intArray_setitem(dataBuffer, i, data[i]);

src/test/java/io/github/metarank/lightgbm4j/RankingIntegrationTest.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,23 @@
1414
public class RankingIntegrationTest {
1515
@Test
1616
public void testLetor() throws LGBMException, IOException {
17-
LGBMDataset train = datasetFromResource("/mq2008/train.txt.gz", null);
18-
LGBMDataset test = datasetFromResource("/mq2008/test.txt.gz", train);
17+
LGBMDataset train = datasetFromResource("/mq2008/train.txt.gz", null, false);
18+
LGBMDataset test = datasetFromResource("/mq2008/test.txt.gz", train, false);
19+
trainBooster(train, test);
20+
train.close();
21+
test.close();
22+
}
23+
24+
@Test
25+
public void testLetorPosition() throws LGBMException, IOException {
26+
LGBMDataset train = datasetFromResource("/mq2008/train.txt.gz", null, true);
27+
LGBMDataset test = datasetFromResource("/mq2008/test.txt.gz", train, true);
28+
trainBooster(train, test);
29+
train.close();
30+
test.close();
31+
}
32+
33+
public void trainBooster(LGBMDataset train, LGBMDataset test) throws LGBMException, IOException {
1934
LGBMBooster booster = LGBMBooster.create(train, "objective=lambdarank metric=ndcg lambdarank_truncation_level=10 max_depth=5 learning_rate=0.1 num_leaves=8");
2035
booster.addValidData(test);
2136
for (int i=0; i<100; i++) {
@@ -30,12 +45,10 @@ public void testLetor() throws LGBMException, IOException {
3045
assertTrue(names.length > 0);
3146
assertTrue(weights.length > 0);
3247
booster.close();
33-
train.close();
34-
test.close();
3548
}
3649

3750

38-
private static LGBMDataset datasetFromResource(String file, LGBMDataset parent) throws LGBMException, IOException {
51+
private static LGBMDataset datasetFromResource(String file, LGBMDataset parent, boolean withPosition) throws LGBMException, IOException {
3952
BufferedReader reader = new BufferedReader(new InputStreamReader(new GZIPInputStream(RankingIntegrationTest.class.getResourceAsStream(file))));
4053
ArrayList<String> lines = reader.lines().map(line -> {
4154
int commentIndex = line.indexOf('#');
@@ -65,18 +78,21 @@ private static LGBMDataset datasetFromResource(String file, LGBMDataset parent)
6578
double[] features = new double[maxFeatureId * rows];
6679
float[] labels = new float[rows];
6780
int[] groups = new int[queries];
81+
int[] positions = new int[rows];
6882
String[] featureNames = new String[maxFeatureId];
6983
for (int i=1; i <= maxFeatureId; i++) {
7084
featureNames[i-1] = "f"+i;
7185
}
7286
int lastGroup = Integer.MIN_VALUE;
7387
int lastCount = 0;
7488
int groupIndex = 0;
89+
int position = 0;
7590
for (int row = 0; row < rows; row++) {
7691
String line = lines.get(row);
7792
String[] tokens = line.split(" ");
7893
float label = Float.parseFloat(tokens[0]);
7994
labels[row] = label;
95+
positions[row] = position;
8096
int group = Integer.parseInt(tokens[1].split(":")[1]);
8197
if (group != lastGroup) {
8298
// next query
@@ -87,6 +103,7 @@ private static LGBMDataset datasetFromResource(String file, LGBMDataset parent)
87103
}
88104
lastGroup = group;
89105
lastCount = 1;
106+
position = 0;
90107
} else {
91108
lastCount++;
92109
}
@@ -97,6 +114,7 @@ private static LGBMDataset datasetFromResource(String file, LGBMDataset parent)
97114
double value = Double.parseDouble(feature[1]);
98115
features[row * maxFeatureId + id] = value;
99116
}
117+
position++;
100118
}
101119
groups[groupIndex] = lastCount;
102120

@@ -106,6 +124,7 @@ private static LGBMDataset datasetFromResource(String file, LGBMDataset parent)
106124
dataset.setFeatureNames(featureNames);
107125
dataset.setField("label", labels);
108126
dataset.setField("group", groups);
127+
if (withPosition) dataset.setField("position", positions);
109128
return dataset;
110129
}
111130

0 commit comments

Comments
 (0)