Skip to content

Commit

Permalink
modify R for #5540
Browse files Browse the repository at this point in the history
  • Loading branch information
junpeng0715 committed Oct 18, 2022
1 parent b6b5d63 commit 2ccdb39
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
SEXP reference) {
R_API_BEGIN();
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
double* p_mat = REAL(data);
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
Expand All @@ -223,7 +223,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
int64_t len = static_cast<int64_t>(Rf_asInteger(len_used_row_indices));
std::vector<int32_t> idxvec(len);
// convert from one-based to zero-based index
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
Expand Down Expand Up @@ -335,7 +335,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
SEXP num_element) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int len = Rf_asInteger(num_element);
int64_t len = Rf_asInteger(num_element);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
if (!strcmp("group", name) || !strcmp("query", name)) {
std::vector<int32_t> vec(len);
Expand All @@ -349,7 +349,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
} else {
std::vector<float> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
for (int i = 0; i < len; ++i) {
for (int64_t i = 0; i < len; ++i) {
vec[i] = static_cast<float>(REAL(field_data)[i]);
}
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
Expand All @@ -365,27 +365,27 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0;
int64_t out_len = 0;
int out_type = 0;
const void* res;
CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
if (!strcmp("group", name) || !strcmp("query", name)) {
auto p_data = reinterpret_cast<const int32_t*>(res);
// convert from boundaries to size
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len - 1; ++i) {
for (int64_t i = 0; i < out_len - 1; ++i) {
INTEGER(field_data)[i] = p_data[i + 1] - p_data[i];
}
} else if (!strcmp("init_score", name)) {
auto p_data = reinterpret_cast<const double*>(res);
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len; ++i) {
for (int64_t i = 0; i < out_len; ++i) {
REAL(field_data)[i] = p_data[i];
}
} else {
auto p_data = reinterpret_cast<const float*>(res);
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len; ++i) {
for (int64_t i = 0; i < out_len; ++i) {
REAL(field_data)[i] = p_data[i];
}
}
Expand Down Expand Up @@ -427,7 +427,7 @@ SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int nrow;
int64_t nrow;
CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
INTEGER(out)[0] = nrow;
return R_NilValue;
Expand Down Expand Up @@ -957,7 +957,7 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
const double* p_mat = REAL(data);
double* ptr_ret = REAL(out_result);
Expand Down

0 comments on commit 2ccdb39

Please sign in to comment.