Skip to content

Commit

Permalink
Merge pull request #25 from paq/cleanup
Browse files Browse the repository at this point in the history
Minor refactoring
  • Loading branch information
vaaaaanquish authored Feb 13, 2021
2 parents 612b22c + d4fbadc commit fe6cadc
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 17 deletions.
1 change: 1 addition & 0 deletions lightgbm-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
#![allow(non_snake_case)]
#![allow(clippy::redundant_static_lifetimes)]
#![allow(clippy::missing_safety_doc)]
#![allow(clippy::upper_case_acronyms)]

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
11 changes: 5 additions & 6 deletions src/booster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use serde_json::Value;

use lightgbm_sys;

use super::{Dataset, Error, Result};
use crate::{Dataset, Error, Result};

/// Core model in LightGBM, containing functions for training, evaluating and predicting.
pub struct Booster {
pub(super) handle: lightgbm_sys::BoosterHandle,
handle: lightgbm_sys::BoosterHandle,
}

impl Booster {
Expand Down Expand Up @@ -159,7 +159,7 @@ impl Booster {

/// Get Feature Names.
pub fn feature_name(&self) -> Result<Vec<String>> {
let num_feature = self.num_feature().unwrap();
let num_feature = self.num_feature()?;
let feature_name_length = 32;
let mut num_feature_names = 0;
let mut out_buffer_len = 0;
Expand Down Expand Up @@ -187,7 +187,7 @@ impl Booster {

// Get Feature Importance
pub fn feature_importance(&self) -> Result<Vec<f64>> {
let num_feature = self.num_feature().unwrap();
let num_feature = self.num_feature()?;
let out_result: Vec<f64> = vec![Default::default(); num_feature as usize];
lgbm_call!(lightgbm_sys::LGBM_BoosterFeatureImportance(
self.handle,
Expand Down Expand Up @@ -231,8 +231,7 @@ mod tests {

fn _train_booster(params: &Value) -> Booster {
let dataset = _read_train_file().unwrap();
let bst = Booster::train(dataset, &params).unwrap();
bst
Booster::train(dataset, &params).unwrap()
}

fn _default_params() -> Value {
Expand Down
10 changes: 5 additions & 5 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use lightgbm_sys;
use std;
use std::ffi::CString;

use super::{Error, Result};
use crate::{Error, Result};

/// Dataset used throughout LightGBM for training.
///
Expand Down Expand Up @@ -31,13 +31,13 @@ use super::{Error, Result};
/// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train").unwrap();
/// ```
pub struct Dataset {
pub(super) handle: lightgbm_sys::DatasetHandle,
pub(crate) handle: lightgbm_sys::DatasetHandle,
}

#[link(name = "c")]
impl Dataset {
fn new(handle: lightgbm_sys::DatasetHandle) -> Self {
Dataset { handle }
Self { handle }
}

/// Create a new `Dataset` from dense array in row-major order.
Expand Down Expand Up @@ -82,7 +82,7 @@ impl Dataset {
lightgbm_sys::C_API_DTYPE_FLOAT32 as i32
))?;

Ok(Dataset::new(handle))
Ok(Self::new(handle))
}

/// Create a new `Dataset` from file.
Expand Down Expand Up @@ -116,7 +116,7 @@ impl Dataset {
&mut handle
))?;

Ok(Dataset::new(handle))
Ok(Self::new(handle))
}
}

Expand Down
9 changes: 3 additions & 6 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct Error {

impl Error {
pub(crate) fn new<S: Into<String>>(desc: S) -> Self {
Error { desc: desc.into() }
Self { desc: desc.into() }
}

/// Check the return value from an LightGBM FFI call, and return the last error message on error.
Expand All @@ -28,11 +28,8 @@ impl Error {
pub(crate) fn check_return_value(ret_val: i32) -> Result<()> {
match ret_val {
0 => Ok(()),
-1 => Err(Error::from_lightgbm()),
_ => panic!(format!(
"unexpected return value '{}', expected 0 or -1",
ret_val
)),
-1 => Err(Self::from_lightgbm()),
_ => panic!("unexpected return value '{}', expected 0 or -1", ret_val),
}
}

Expand Down

0 comments on commit fe6cadc

Please sign in to comment.