1414public class RankingIntegrationTest {
1515 @ Test
1616 public void testLetor () throws LGBMException , IOException {
17- LGBMDataset train = datasetFromResource ("/mq2008/train.txt.gz" , null );
18- LGBMDataset test = datasetFromResource ("/mq2008/test.txt.gz" , train );
17+ LGBMDataset train = datasetFromResource ("/mq2008/train.txt.gz" , null , false );
18+ LGBMDataset test = datasetFromResource ("/mq2008/test.txt.gz" , train , false );
19+ trainBooster (train , test );
20+ train .close ();
21+ test .close ();
22+ }
23+
24+ @ Test
25+ public void testLetorPosition () throws LGBMException , IOException {
26+ LGBMDataset train = datasetFromResource ("/mq2008/train.txt.gz" , null , true );
27+ LGBMDataset test = datasetFromResource ("/mq2008/test.txt.gz" , train , true );
28+ trainBooster (train , test );
29+ train .close ();
30+ test .close ();
31+ }
32+
/**
 * Shared training routine used by the ranking tests.
 *
 * Creates a booster from {@code train} with a lambdarank / ndcg parameter
 * string, registers {@code test} as validation data, runs a 100-iteration
 * loop, asserts that the {@code names} and {@code weights} arrays are
 * non-empty, and closes the booster.
 *
 * The body of the loop and the statements that produce {@code names} and
 * {@code weights} are elided in this view (diff hunk boundary), so they are
 * not described here.
 *
 * This method does not close {@code train} or {@code test}; the calling
 * tests close them after this method returns.
 *
 * NOTE(review): in the visible lines booster.close() is not inside a
 * finally block, so a failing assertTrue would skip it -- confirm against
 * the elided part and consider try/finally.
 */
33+ public void trainBooster (LGBMDataset train , LGBMDataset test ) throws LGBMException , IOException {
1934 LGBMBooster booster = LGBMBooster .create (train , "objective=lambdarank metric=ndcg lambdarank_truncation_level=10 max_depth=5 learning_rate=0.1 num_leaves=8" );
// Attach the held-out split as validation data before iterating.
2035 booster .addValidData (test );
2136 for (int i =0 ; i <100 ; i ++) {
@@ -30,12 +45,10 @@ public void testLetor() throws LGBMException, IOException {
// Sanity checks: both arrays must contain at least one entry.
3045 assertTrue (names .length > 0 );
3146 assertTrue (weights .length > 0 );
3247 booster .close ();
33- train .close ();
34- test .close ();
3548 }
3649
3750
38- private static LGBMDataset datasetFromResource (String file , LGBMDataset parent ) throws LGBMException , IOException {
/**
 * Builds an LGBMDataset from a gzip-compressed classpath resource.
 *
 * From the visible lines: each row is split on single spaces; token 0 is
 * parsed as the float label and token 1 is split on ':' with the second
 * part parsed as the integer query/group id. Consecutive rows with the
 * same group id are counted into {@code groups}; feature values are stored
 * row-major in {@code features} at {@code row * maxFeatureId + id}.
 * Feature names are generated as "f1" .. "f<maxFeatureId>". The "label"
 * and "group" fields are always set; the "position" field is set only
 * when {@code withPosition} is true.
 *
 * Several spans of this method are elided in this view (diff hunk
 * boundaries): the rest of the line-mapping lambda, the computation of
 * rows / queries / maxFeatureId, the start of the new-query branch, the
 * feature-token loop header, and the creation of {@code dataset}.
 *
 * @param file classpath resource path, resolved relative to RankingIntegrationTest
 * @param parent passed in by callers as null for the train split and as the
 *               train dataset for the test split; its use is in an elided span
 * @param withPosition when true, the per-row positions array is attached as
 *                     the "position" field
 * @return the populated dataset
 *
 * NOTE(review): no close of reader is visible in the shown lines, and the
 * result of getResourceAsStream is not null-checked -- confirm against the
 * elided part.
 */
51+ private static LGBMDataset datasetFromResource (String file , LGBMDataset parent , boolean withPosition ) throws LGBMException , IOException {
3952 BufferedReader reader = new BufferedReader (new InputStreamReader (new GZIPInputStream (RankingIntegrationTest .class .getResourceAsStream (file ))));
4053 ArrayList <String > lines = reader .lines ().map (line -> {
4154 int commentIndex = line .indexOf ('#' );
@@ -65,18 +78,21 @@ private static LGBMDataset datasetFromResource(String file, LGBMDataset parent)
// Flat row-major feature matrix plus one label and one position per row,
// and one row count per query.
6578 double [] features = new double [maxFeatureId * rows ];
6679 float [] labels = new float [rows ];
6780 int [] groups = new int [queries ];
81+ int [] positions = new int [rows ];
6882 String [] featureNames = new String [maxFeatureId ];
// Feature ids are 1-based in the names ("f1" is stored at index 0).
6983 for (int i =1 ; i <= maxFeatureId ; i ++) {
7084 featureNames [i -1 ] = "f" +i ;
7185 }
// Integer.MIN_VALUE acts as the "no query seen yet" sentinel.
7286 int lastGroup = Integer .MIN_VALUE ;
7387 int lastCount = 0 ;
7488 int groupIndex = 0 ;
89+ int position = 0 ;
7590 for (int row = 0 ; row < rows ; row ++) {
7691 String line = lines .get (row );
7792 String [] tokens = line .split (" " );
7893 float label = Float .parseFloat (tokens [0 ]);
7994 labels [row ] = label ;
// NOTE(review): this write happens BEFORE the new-query check below resets
// `position` to 0. As written, the first row of every query after the first
// stores the previous query's row count instead of 0, and the rows that
// follow it store 1, 2, ... Moving this assignment after the if/else (or
// resetting before it) would give 0-based positions within each query.
95+ positions [row ] = position ;
8096 int group = Integer .parseInt (tokens [1 ].split (":" )[1 ]);
8197 if (group != lastGroup ) {
8298 // next query
@@ -87,6 +103,7 @@ private static LGBMDataset datasetFromResource(String file, LGBMDataset parent)
87103 }
88104 lastGroup = group ;
89105 lastCount = 1 ;
106+ position = 0 ;
90107 } else {
91108 lastCount ++;
92109 }
@@ -97,6 +114,7 @@ private static LGBMDataset datasetFromResource(String file, LGBMDataset parent)
97114 double value = Double .parseDouble (feature [1 ]);
98115 features [row * maxFeatureId + id ] = value ;
99116 }
117+ position ++;
100118 }
// Flush the row count of the final query, which no later group change records.
101119 groups [groupIndex ] = lastCount ;
102120
@@ -106,6 +124,7 @@ private static LGBMDataset datasetFromResource(String file, LGBMDataset parent)
106124 dataset .setFeatureNames (featureNames );
107125 dataset .setField ("label" , labels );
108126 dataset .setField ("group" , groups );
// The position field is optional so the same loader serves both tests.
127+ if (withPosition ) dataset .setField ("position" , positions );
109128 return dataset ;
110129 }
111130
0 commit comments