Skip to content

Commit 4198c16

Browse files
authored
add ranking integration test (#57)
1 parent 7bf02f7 commit 4198c16

File tree

4 files changed

+119
-0
lines changed

4 files changed

+119
-0
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.so filter=lfs diff=lfs merge=lfs -text
44
*.dll filter=lfs diff=lfs merge=lfs -text
55
*.dylib filter=lfs diff=lfs merge=lfs -text
6+
*.gz filter=lfs diff=lfs merge=lfs -text
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package io.github.metarank.lightgbm4j;
2+
3+
import org.junit.jupiter.api.Test;
4+
5+
import java.io.BufferedReader;
6+
import java.io.IOException;
7+
import java.io.InputStreamReader;
8+
import java.util.*;
9+
import java.util.stream.Collectors;
10+
import java.util.zip.GZIPInputStream;
11+
12+
import static org.junit.jupiter.api.Assertions.assertTrue;
13+
14+
public class RankingIntegrationTest {
15+
@Test
16+
public void testLetor() throws LGBMException, IOException {
17+
LGBMDataset train = datasetFromResource("/mq2008/train.txt.gz", null);
18+
LGBMDataset test = datasetFromResource("/mq2008/test.txt.gz", train);
19+
LGBMBooster booster = LGBMBooster.create(train, "objective=lambdarank metric=ndcg lambdarank_truncation_level=10 max_depth=5 learning_rate=0.1 num_leaves=8");
20+
booster.addValidData(test);
21+
for (int i=0; i<100; i++) {
22+
booster.updateOneIter();
23+
double[] eval1 = booster.getEval(0);
24+
double[] eval2 = booster.getEval(1);
25+
System.out.println("train " + eval1[0] + " test " + eval2[0]);
26+
assertTrue(eval1[0] > 0.5);
27+
}
28+
String[] names = booster.getFeatureNames();
29+
double[] weights = booster.featureImportance(0, LGBMBooster.FeatureImportanceType.GAIN);
30+
assertTrue(names.length > 0);
31+
assertTrue(weights.length > 0);
32+
booster.close();
33+
train.close();
34+
test.close();
35+
}
36+
37+
38+
private static LGBMDataset datasetFromResource(String file, LGBMDataset parent) throws LGBMException, IOException {
39+
BufferedReader reader = new BufferedReader(new InputStreamReader(new GZIPInputStream(RankingIntegrationTest.class.getResourceAsStream(file))));
40+
ArrayList<String> lines = reader.lines().map(line -> {
41+
int commentIndex = line.indexOf('#');
42+
if (commentIndex >= 0) {
43+
return line.substring(0, commentIndex);
44+
} else {
45+
return line;
46+
}
47+
}).collect(Collectors.toCollection(ArrayList::new));
48+
int maxFeatureId = 0; // features are 1-indexed!
49+
Set<Integer> queriesSet = new HashSet<>();
50+
for (String line: lines) {
51+
String[] tokens = line.split(" ");
52+
int group = Integer.parseInt(tokens[1].split(":")[1]);
53+
queriesSet.add(group);
54+
for (int i = 2; i < tokens.length; i++) {
55+
String[] parts = tokens[i].split(":");
56+
int featureId = Integer.parseInt(parts[0]);
57+
if (featureId > maxFeatureId) {
58+
maxFeatureId = featureId;
59+
}
60+
}
61+
}
62+
63+
int rows = lines.size();
64+
int queries = queriesSet.size();
65+
double[] features = new double[maxFeatureId * rows];
66+
float[] labels = new float[rows];
67+
int[] groups = new int[queries];
68+
String[] featureNames = new String[maxFeatureId];
69+
for (int i=1; i <= maxFeatureId; i++) {
70+
featureNames[i-1] = "f"+i;
71+
}
72+
int lastGroup = Integer.MIN_VALUE;
73+
int lastCount = 0;
74+
int groupIndex = 0;
75+
for (int row = 0; row < rows; row++) {
76+
String line = lines.get(row);
77+
String[] tokens = line.split(" ");
78+
float label = Float.parseFloat(tokens[0]);
79+
labels[row] = label;
80+
int group = Integer.parseInt(tokens[1].split(":")[1]);
81+
if (group != lastGroup) {
82+
// next query
83+
if (lastCount > 0) {
84+
// so it's not the first one
85+
groups[groupIndex] = lastCount;
86+
groupIndex++;
87+
}
88+
lastGroup = group;
89+
lastCount = 1;
90+
} else {
91+
lastCount++;
92+
}
93+
94+
for (int i=2; i < tokens.length; i++) {
95+
String[] feature = tokens[i].split(":");
96+
int id = Integer.parseInt(feature[0]) - 1;
97+
double value = Double.parseDouble(feature[1]);
98+
features[row * maxFeatureId + id] = value;
99+
}
100+
}
101+
groups[groupIndex] = lastCount;
102+
103+
104+
reader.close();
105+
LGBMDataset dataset = LGBMDataset.createFromMat(features, rows, maxFeatureId, true, "", parent);
106+
dataset.setFeatureNames(featureNames);
107+
dataset.setField("label", labels);
108+
dataset.setField("group", groups);
109+
return dataset;
110+
}
111+
112+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:d12b6a9ba0577460540a6c27630866b880d30f0b33431e27cf21d5b6efc80f1e
3+
size 345457
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:6f5707c2cb822750c16be00edf807c0eb41776020a0c0eb8b4f36fdead63dcd7
3+
size 1126873

0 commit comments

Comments
 (0)