@@ -65,6 +65,14 @@ SEXP wrapped_R_raw(void *len) {
65
65
return Rf_allocVector (RAWSXP, *(reinterpret_cast <R_xlen_t*>(len)));
66
66
}
67
67
68
+ SEXP wrapped_R_int (void *len) {
69
+ return Rf_allocVector (INTSXP, *(reinterpret_cast <R_xlen_t*>(len)));
70
+ }
71
+
72
+ SEXP wrapped_R_real (void *len) {
73
+ return Rf_allocVector (REALSXP, *(reinterpret_cast <R_xlen_t*>(len)));
74
+ }
75
+
68
76
SEXP wrapped_Rf_mkChar (void *txt) {
69
77
return Rf_mkChar (reinterpret_cast <char *>(txt));
70
78
}
@@ -84,6 +92,14 @@ SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
84
92
return R_UnwindProtect (wrapped_R_raw, reinterpret_cast <void *>(&len), throw_R_memerr, cont_token, *cont_token);
85
93
}
86
94
95
+ SEXP safe_R_int (R_xlen_t len, SEXP *cont_token) {
96
+ return R_UnwindProtect (wrapped_R_int, reinterpret_cast <void *>(&len), throw_R_memerr, cont_token, *cont_token);
97
+ }
98
+
99
+ SEXP safe_R_real (R_xlen_t len, SEXP *cont_token) {
100
+ return R_UnwindProtect (wrapped_R_real, reinterpret_cast <void *>(&len), throw_R_memerr, cont_token, *cont_token);
101
+ }
102
+
87
103
SEXP safe_R_mkChar (char *txt, SEXP *cont_token) {
88
104
return R_UnwindProtect (wrapped_Rf_mkChar, reinterpret_cast <void *>(txt), throw_R_memerr, cont_token, *cont_token);
89
105
}
@@ -851,6 +867,76 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
851
867
R_API_END ();
852
868
}
853
869
870
+ struct SparseOutputPointers {
871
+ void * indptr;
872
+ int32_t * indices;
873
+ void * data;
874
+ int indptr_type;
875
+ int data_type;
876
+ SparseOutputPointers (void * indptr, int32_t * indices, void * data)
877
+ : indptr(indptr), indices(indices), data(data) {}
878
+ };
879
+
880
+ void delete_SparseOutputPointers (SparseOutputPointers *ptr) {
881
+ LGBM_BoosterFreePredictSparse (ptr->indptr , ptr->indices , ptr->data , C_API_DTYPE_INT32, C_API_DTYPE_FLOAT64);
882
+ delete ptr;
883
+ }
884
+
885
+ SEXP LGBM_BoosterPredictSparseOutput_R (SEXP handle,
886
+ SEXP indptr,
887
+ SEXP indices,
888
+ SEXP data,
889
+ SEXP is_csr,
890
+ SEXP nrows,
891
+ SEXP ncols,
892
+ SEXP start_iteration,
893
+ SEXP num_iteration,
894
+ SEXP parameter) {
895
+ SEXP cont_token = PROTECT (R_MakeUnwindCont ());
896
+ R_API_BEGIN ();
897
+ _AssertBoosterHandleNotNull (handle);
898
+ const char * out_names[] = {" indptr" , " indices" , " data" , " " };
899
+ SEXP out = PROTECT (Rf_mkNamed (VECSXP, out_names));
900
+ const char * parameter_ptr = CHAR (PROTECT (Rf_asChar (parameter)));
901
+
902
+ int64_t out_len[2 ];
903
+ void *out_indptr;
904
+ int32_t *out_indices;
905
+ void *out_data;
906
+
907
+ CHECK_CALL (LGBM_BoosterPredictSparseOutput (R_ExternalPtrAddr (handle),
908
+ INTEGER (indptr), C_API_DTYPE_INT32, INTEGER (indices),
909
+ REAL (data), C_API_DTYPE_FLOAT64,
910
+ Rf_xlength (indptr), Rf_xlength (data),
911
+ Rf_asLogical (is_csr)? Rf_asInteger (ncols) : Rf_asInteger (nrows),
912
+ C_API_PREDICT_CONTRIB, Rf_asInteger (start_iteration), Rf_asInteger (num_iteration),
913
+ parameter_ptr,
914
+ Rf_asLogical (is_csr)? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC,
915
+ out_len, &out_indptr, &out_indices, &out_data));
916
+
917
+ std::unique_ptr<SparseOutputPointers, decltype (&delete_SparseOutputPointers)> pointers_struct = {
918
+ new SparseOutputPointers (
919
+ out_indptr,
920
+ out_indices,
921
+ out_data),
922
+ &delete_SparseOutputPointers
923
+ };
924
+
925
+ SEXP out_indptr_R = safe_R_int (out_len[1 ], &cont_token);
926
+ SET_VECTOR_ELT (out, 0 , out_indptr_R);
927
+ SEXP out_indices_R = safe_R_int (out_len[0 ], &cont_token);
928
+ SET_VECTOR_ELT (out, 1 , out_indices_R);
929
+ SEXP out_data_R = safe_R_real (out_len[0 ], &cont_token);
930
+ SET_VECTOR_ELT (out, 2 , out_data_R);
931
+ std::memcpy (INTEGER (out_indptr_R), out_indptr, out_len[1 ]*sizeof (int ));
932
+ std::memcpy (INTEGER (out_indices_R), out_indices, out_len[0 ]*sizeof (int ));
933
+ std::memcpy (REAL (out_data_R), out_data, out_len[0 ]*sizeof (double ));
934
+
935
+ UNPROTECT (3 );
936
+ return out;
937
+ R_API_END ();
938
+ }
939
+
854
940
SEXP LGBM_BoosterSaveModel_R (SEXP handle,
855
941
SEXP num_iteration,
856
942
SEXP feature_importance_type,
@@ -975,6 +1061,7 @@ static const R_CallMethodDef CallEntries[] = {
975
1061
{" LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8 },
976
1062
{" LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14 },
977
1063
{" LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11 },
1064
+ {" LGBM_BoosterPredictSparseOutput_R" , (DL_FUNC) &LGBM_BoosterPredictSparseOutput_R, 10 },
978
1065
{" LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4 },
979
1066
{" LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3 },
980
1067
{" LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3 },
0 commit comments