Skip to content

Commit 6f92d47

Browse files
[R-package] Add sparse feature contribution predictions (#5108)
* add predcontrib for sparse inputs
* register newly-added function
* comments
* correct wrong types in test
* forcibly take transpose function from Matrix
* keep row names, test comparison to dense inputs
* workaround for passing test while PR for row names is not merged
* Update R-package/R/lgb.Predictor.R
  Co-authored-by: James Lamb <[email protected]>
* Update R-package/R/lgb.Predictor.R
  Co-authored-by: James Lamb <[email protected]>
* Update R-package/R/lgb.Predictor.R
  Co-authored-by: James Lamb <[email protected]>
* proper handling of integer overflow
* add test for CSR contrib row names
* add more tests for predict(<sparse>, predcontrib=TRUE)
* make linter happy
* linter
* linter
* check error messages for bad input shapes
* fix regex
* hard-coded number of columns in regex for tests

Co-authored-by: James Lamb <[email protected]>
1 parent 688f73d commit 6f92d47

File tree

5 files changed

+312
-1
lines changed

5 files changed

+312
-1
lines changed

R-package/NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ export(saveRDS.lgb.Booster)
3737
export(set_field)
3838
export(slice)
3939
import(methods)
40+
importClassesFrom(Matrix,dgCMatrix)
41+
importClassesFrom(Matrix,dgRMatrix)
42+
importClassesFrom(Matrix,dsparseMatrix)
43+
importClassesFrom(Matrix,dsparseVector)
4044
importFrom(Matrix,Matrix)
4145
importFrom(R6,R6Class)
4246
importFrom(data.table,":=")
@@ -51,6 +55,7 @@ importFrom(graphics,barplot)
5155
importFrom(graphics,par)
5256
importFrom(jsonlite,fromJSON)
5357
importFrom(methods,is)
58+
importFrom(methods,new)
5459
importFrom(parallel,detectCores)
5560
importFrom(stats,quantile)
5661
importFrom(utils,modifyList)

R-package/R/lgb.Predictor.R

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#' @importFrom methods is
1+
#' @importFrom methods is new
2+
#' @importClassesFrom Matrix dsparseMatrix dsparseVector dgCMatrix dgRMatrix
23
#' @importFrom R6 R6Class
34
#' @importFrom utils read.delim
45
Predictor <- R6::R6Class(
@@ -126,6 +127,113 @@ Predictor <- R6::R6Class(
126127
num_row <- nrow(preds)
127128
preds <- as.vector(t(preds))
128129

130+
} else if (predcontrib && inherits(data, c("dsparseMatrix", "dsparseVector"))) {
131+
132+
# Number of columns (features) the model was fitted to.
ncols <- .Call(LGBM_BoosterGetNumFeature_R, private$handle)
# Length-one buffer passed to the C routine; presumably filled in place
# with the number of classes of the model.
num_class <- integer(1L)
.Call(LGBM_BoosterGetNumClasses_R, private$handle, num_class)
# Width of the contribution output: (ncols + 1) columns per class, with at
# least one class. The product is computed in double precision because the
# integer product overflows to NA for very wide multiclass models, and
# once that has happened the operands can no longer be recovered from it.
ncols_out <- (as.numeric(ncols) + 1.0) * as.numeric(max(num_class, 1L))
if (ncols_out <= .Machine$integer.max) {
  # Fits in an R integer: keep the integer type in the common case.
  ncols_out <- as.integer(ncols_out)
}
139+
if (!inherits(data, "dsparseVector") && ncols_out > .Machine$integer.max) {
140+
stop("Resulting matrix of feature contributions is too large for R to handle.")
141+
}
142+
143+
if (inherits(data, "dsparseVector")) {
144+
145+
if (length(data) > ncols) {
146+
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
147+
, ncols
148+
, length(data)))
149+
}
150+
res <- .Call(
151+
LGBM_BoosterPredictSparseOutput_R
152+
, private$handle
153+
, c(0L, as.integer(length(data@x)))
154+
, data@i - 1L
155+
, data@x
156+
, TRUE
157+
, 1L
158+
, ncols
159+
, start_iteration
160+
, num_iteration
161+
, private$params
162+
)
163+
out <- methods::new("dsparseVector")
164+
out@i <- res$indices + 1L
165+
out@x <- res$data
166+
out@length <- ncols_out
167+
return(out)
168+
169+
} else if (inherits(data, "dgRMatrix")) {
170+
171+
if (ncol(data) > ncols) {
172+
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
173+
, ncols
174+
, ncol(data)))
175+
}
176+
res <- .Call(
177+
LGBM_BoosterPredictSparseOutput_R
178+
, private$handle
179+
, data@p
180+
, data@j
181+
, data@x
182+
, TRUE
183+
, nrow(data)
184+
, ncols
185+
, start_iteration
186+
, num_iteration
187+
, private$params
188+
)
189+
out <- methods::new("dgRMatrix")
190+
out@p <- res$indptr
191+
out@j <- res$indices
192+
out@x <- res$data
193+
out@Dim <- as.integer(c(nrow(data), ncols_out))
194+
195+
} else if (inherits(data, "dgCMatrix")) {
196+
197+
if (ncol(data) != ncols) {
198+
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
199+
, ncols
200+
, ncol(data)))
201+
}
202+
res <- .Call(
203+
LGBM_BoosterPredictSparseOutput_R
204+
, private$handle
205+
, data@p
206+
, data@i
207+
, data@x
208+
, FALSE
209+
, nrow(data)
210+
, ncols
211+
, start_iteration
212+
, num_iteration
213+
, private$params
214+
)
215+
out <- methods::new("dgCMatrix")
216+
out@p <- res$indptr
217+
out@i <- res$indices
218+
out@x <- res$data
219+
out@Dim <- as.integer(c(nrow(data), length(res$indptr) - 1L))
220+
221+
} else {
222+
223+
stop(sprintf("Predictions on sparse inputs are only allowed for '%s', '%s', '%s' - got: %s"
224+
, "dsparseVector"
225+
, "dgRMatrix"
226+
, "dgCMatrix"
227+
, paste(class(data)
228+
, collapse = ", ")))
229+
230+
}
231+
232+
if (NROW(row.names(data))) {
233+
out@Dimnames[[1L]] <- row.names(data)
234+
}
235+
return(out)
236+
129237
} else {
130238

131239
# Not a file, we need to predict from R object

R-package/src/lightgbm_R.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ SEXP wrapped_R_raw(void *len) {
6565
return Rf_allocVector(RAWSXP, *(reinterpret_cast<R_xlen_t*>(len)));
6666
}
6767

68+
// Allocates an R integer vector (INTSXP).
// `len` must point to an R_xlen_t holding the requested length; the untyped
// pointer signature lets this be used as the callback of R_UnwindProtect.
SEXP wrapped_R_int(void *len) {
  const R_xlen_t n_elts = *reinterpret_cast<R_xlen_t*>(len);
  return Rf_allocVector(INTSXP, n_elts);
}
71+
72+
// Allocates an R double vector (REALSXP).
// `len` must point to an R_xlen_t holding the requested length; the untyped
// pointer signature lets this be used as the callback of R_UnwindProtect.
SEXP wrapped_R_real(void *len) {
  const R_xlen_t n_elts = *reinterpret_cast<R_xlen_t*>(len);
  return Rf_allocVector(REALSXP, n_elts);
}
75+
6876
SEXP wrapped_Rf_mkChar(void *txt) {
6977
return Rf_mkChar(reinterpret_cast<char*>(txt));
7078
}
@@ -84,6 +92,14 @@ SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
8492
return R_UnwindProtect(wrapped_R_raw, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
8593
}
8694

95+
// Allocates an R integer vector of length `len` by running wrapped_R_int
// under R_UnwindProtect, with throw_R_memerr registered as the cleanup
// function and `cont_token` as both its data and the continuation token.
SEXP safe_R_int(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = reinterpret_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_int, len_arg, throw_R_memerr, cont_token, *cont_token);
}
98+
99+
// Allocates an R double vector of length `len` by running wrapped_R_real
// under R_UnwindProtect, with throw_R_memerr registered as the cleanup
// function and `cont_token` as both its data and the continuation token.
SEXP safe_R_real(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = reinterpret_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_real, len_arg, throw_R_memerr, cont_token, *cont_token);
}
102+
87103
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
88104
return R_UnwindProtect(wrapped_Rf_mkChar, reinterpret_cast<void*>(txt), throw_R_memerr, cont_token, *cont_token);
89105
}
@@ -851,6 +867,76 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
851867
R_API_END();
852868
}
853869

870+
struct SparseOutputPointers {
871+
void* indptr;
872+
int32_t* indices;
873+
void* data;
874+
int indptr_type;
875+
int data_type;
876+
SparseOutputPointers(void* indptr, int32_t* indices, void* data)
877+
: indptr(indptr), indices(indices), data(data) {}
878+
};
879+
880+
void delete_SparseOutputPointers(SparseOutputPointers *ptr) {
881+
LGBM_BoosterFreePredictSparse(ptr->indptr, ptr->indices, ptr->data, C_API_DTYPE_INT32, C_API_DTYPE_FLOAT64);
882+
delete ptr;
883+
}
884+
885+
// .Call entry point: feature-contribution prediction on CSR/CSC input.
// Returns a named R list with elements "indptr", "indices" and "data" that
// hold copies of the sparse output produced by LGBM_BoosterPredictSparseOutput.
SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP is_csr,
  SEXP nrows,
  SEXP ncols,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const char* out_names[] = {"indptr", "indices", "data", ""};
  SEXP out = PROTECT(Rf_mkNamed(VECSXP, out_names));
  const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));

  // CSR input is described by its column count, CSC input by its row count.
  const bool input_is_csr = Rf_asLogical(is_csr) != 0;
  const int num_col_or_row = input_is_csr ? Rf_asInteger(ncols) : Rf_asInteger(nrows);
  const int matrix_type = input_is_csr ? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC;

  // Filled by the library: out_len[0] sizes indices/data, out_len[1] sizes indptr.
  int64_t out_len[2];
  void *out_indptr;
  int32_t *out_indices;
  void *out_data;

  CHECK_CALL(LGBM_BoosterPredictSparseOutput(R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr), Rf_xlength(data),
    num_col_or_row,
    C_API_PREDICT_CONTRIB, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    parameter_ptr,
    matrix_type,
    out_len, &out_indptr, &out_indices, &out_data));

  // Hand the library-owned buffers to a scope guard so that they are released
  // on every exit path, including a failed R allocation below.
  std::unique_ptr<SparseOutputPointers, decltype(&delete_SparseOutputPointers)> buffers_guard(
    new SparseOutputPointers(out_indptr, out_indices, out_data),
    &delete_SparseOutputPointers);

  const int64_t n_elem = out_len[0];
  const int64_t n_indptr = out_len[1];

  // Each new vector is stored in the protected list right after allocation.
  SEXP indptr_R = safe_R_int(n_indptr, &cont_token);
  SET_VECTOR_ELT(out, 0, indptr_R);
  SEXP indices_R = safe_R_int(n_elem, &cont_token);
  SET_VECTOR_ELT(out, 1, indices_R);
  SEXP data_R = safe_R_real(n_elem, &cont_token);
  SET_VECTOR_ELT(out, 2, data_R);

  // Copy the results into R-managed memory.
  std::memcpy(INTEGER(indptr_R), out_indptr, n_indptr * sizeof(int));
  std::memcpy(INTEGER(indices_R), out_indices, n_elem * sizeof(int));
  std::memcpy(REAL(data_R), out_data, n_elem * sizeof(double));

  // Balances the three PROTECT calls: cont_token, out, and the parameter string.
  UNPROTECT(3);
  return out;
  R_API_END();
}
939+
854940
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
855941
SEXP num_iteration,
856942
SEXP feature_importance_type,
@@ -975,6 +1061,7 @@ static const R_CallMethodDef CallEntries[] = {
9751061
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
9761062
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
9771063
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
1064+
{"LGBM_BoosterPredictSparseOutput_R", (DL_FUNC) &LGBM_BoosterPredictSparseOutput_R, 10},
9781065
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
9791066
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
9801067
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},

R-package/src/lightgbm_R.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,35 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R(
574574
SEXP out_result
575575
);
576576

577+
/*!
578+
* \brief make feature contribution prediction for a new Dataset
579+
* \param handle Booster handle
580+
* \param indptr array with the index pointer of the data in CSR or CSC format
581+
* \param indices array with the non-zero indices of the data in CSR or CSC format
582+
* \param data array with the non-zero values of the data in CSR or CSC format
583+
* \param is_csr whether the input data is in CSR format or not (pass FALSE for CSC)
584+
* \param nrows number of rows in the data
585+
* \param ncols number of columns in the data
586+
* \param start_iteration Start index of the iteration to predict
587+
* \param num_iteration number of iteration for prediction, <= 0 means no limit
588+
* \param parameter additional parameters
589+
* \return An R list with entries "indptr", "indices", "data", constituting the
590+
* feature contributions in sparse format, in the same storage order as
591+
* the input data.
592+
*/
593+
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictSparseOutput_R(
594+
SEXP handle,
595+
SEXP indptr,
596+
SEXP indices,
597+
SEXP data,
598+
SEXP is_csr,
599+
SEXP nrows,
600+
SEXP ncols,
601+
SEXP start_iteration,
602+
SEXP num_iteration,
603+
SEXP parameter
604+
);
605+
577606
/*!
578607
* \brief save model into file
579608
* \param handle Booster handle

R-package/tests/testthat/test_Predictor.R

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
library(Matrix)
2+
13
VERBOSITY <- as.integer(
24
Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
35
)
@@ -116,6 +118,84 @@ test_that("start_iteration works correctly", {
116118
expect_equal(pred_leaf1, pred_leaf2)
117119
})
118120

121+
test_that("Feature contributions from sparse inputs produce sparse outputs", {
  # Small regression model on mtcars: first column is the label,
  # the remaining columns are the features.
  data(mtcars)
  features <- as.matrix(mtcars[, -1L])
  target <- as.numeric(mtcars[, 1L])
  train_set <- lgb.Dataset(features, label = target, params = list(max_bins = 5L))
  bst <- lgb.train(
    data = train_set
    , obj = "regression"
    , nrounds = 5L
    , verbose = VERBOSITY
    , params = list(min_data_in_leaf = 5L)
  )
  contrib <- function(newdata) {
    predict(bst, newdata, predcontrib = TRUE)
  }

  # dense input is the reference result
  contrib_dense <- contrib(features)

  # CSC input gives CSC output with the same values as the dense result
  contrib_csc <- contrib(as(features, "CsparseMatrix"))
  expect_s4_class(contrib_csc, "dgCMatrix")
  expect_equal(unname(contrib_dense), unname(as.matrix(contrib_csc)))

  # CSR input gives CSR output that matches the CSC result
  contrib_csr <- contrib(as(features, "RsparseMatrix"))
  expect_s4_class(contrib_csr, "dgRMatrix")
  expect_equal(as(contrib_csr, "CsparseMatrix"), contrib_csc)

  # a single row as a sparse vector gives a sparse vector
  # that matches the first row of the CSC result
  first_row <- as(features[1L, , drop = FALSE], "sparseVector")
  contrib_spv <- contrib(first_row)
  expect_s4_class(contrib_spv, "dsparseVector")
  expect_equal(
    Matrix::t(as(contrib_spv, "CsparseMatrix"))
    , unname(contrib_csc[1L, , drop = FALSE])
  )
})
151+
152+
test_that("Sparse feature contribution predictions do not take inputs with wrong number of columns", {
  # Model fitted on the 10 feature columns of mtcars.
  data(mtcars)
  features <- as.matrix(mtcars[, -1L])
  target <- as.numeric(mtcars[, 1L])
  train_set <- lgb.Dataset(features, label = target, params = list(max_bins = 5L))
  bst <- lgb.train(
    data = train_set
    , obj = "regression"
    , nrounds = 5L
    , verbose = VERBOSITY
    , params = list(min_data_in_leaf = 5L)
  )

  # too many columns (every feature duplicated), CSC
  too_wide_csc <- as(features[, rep(1L:10L, times = 2L)], "CsparseMatrix")
  expect_error(
    predict(bst, too_wide_csc, predcontrib = TRUE)
    , regexp = "input data has 20 columns"
  )

  # too many columns, CSR
  too_wide_csr <- as(too_wide_csc, "RsparseMatrix")
  expect_error(
    predict(bst, too_wide_csr, predcontrib = TRUE)
    , regexp = "input data has 20 columns"
  )

  # too few columns, CSC
  too_narrow_csc <- as(too_wide_csr, "CsparseMatrix")[, seq_len(3L)]
  expect_error(
    predict(bst, too_narrow_csc, predcontrib = TRUE)
    , regexp = "input data has 3 columns"
  )
})
176+
177+
test_that("Feature contribution predictions do not take non-general CSR or CSC inputs", {
  # Random label and a symmetric 25 x 25 matrix; the label is drawn first so
  # the random stream is consumed in a fixed order.
  set.seed(123L)
  target <- runif(25L)
  gram <- crossprod(matrix(runif(625L), nrow = 25L, ncol = 25L))
  sym_dense <- as(gram, "symmetricMatrix")
  sym_csc <- as(sym_dense, "sparseMatrix")
  sym_csr <- as(sym_csc, "RsparseMatrix")

  train_set <- lgb.Dataset(as.matrix(sym_dense), label = target, params = list(max_bins = 5L))
  bst <- lgb.train(
    data = train_set
    , obj = "regression"
    , nrounds = 5L
    , verbose = VERBOSITY
    , params = list(min_data_in_leaf = 5L)
  )

  # symmetric sparse storage must be rejected in both layouts
  expect_error(predict(bst, sym_csc, predcontrib = TRUE))
  expect_error(predict(bst, sym_csr, predcontrib = TRUE))
})
198+
119199
test_that("predict() params should override keyword argument for raw-score predictions", {
120200
data(agaricus.train, package = "lightgbm")
121201
X <- agaricus.train$data
@@ -321,6 +401,8 @@ test_that("predict() params should override keyword argument for feature contrib
321401
.expect_has_row_names(pred, Xcsc)
322402
pred <- predict(bst, Xcsc, predcontrib = TRUE)
323403
.expect_has_row_names(pred, Xcsc)
404+
pred <- predict(bst, as(Xcsc, "RsparseMatrix"), predcontrib = TRUE)
405+
.expect_has_row_names(pred, Xcsc)
324406

325407
# sparse matrix without row names
326408
Xcopy <- Xcsc

0 commit comments

Comments
 (0)