Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R] xgb.importance: base-1 indexing and removal of deprecated #11099

Merged
merged 1 commit into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions R-package/R/xgb.importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,11 @@
#' @param feature_names Character vector used to overwrite the feature names
#' of the model. The default is `NULL` (use original feature names).
#' @param model Object of class `xgb.Booster`.
#' @param trees An integer vector of tree indices that should be included
#' @param trees An integer vector of (base-1) tree indices that should be included
#' into the importance calculation (only for the "gbtree" booster).
#' The default (`NULL`) parses all trees.
#' It could be useful, e.g., in multiclass classification to get feature importances
#' for each class separately. *Important*: the tree index in XGBoost models
#' is zero-based (e.g., use `trees = 0:4` for the first five trees).
#' @param data Deprecated.
#' @param label Deprecated.
#' @param target Deprecated.
#' for each class separately.
#' @return A `data.table` with the following columns:
#'
#' For a tree model:
Expand Down Expand Up @@ -94,14 +90,14 @@
#'
#' # inspect importances separately for each class:
#' xgb.importance(
#' model = mbst, trees = seq(from = 0, by = nclass, length.out = nrounds)
#' )
#' xgb.importance(
#' model = mbst, trees = seq(from = 1, by = nclass, length.out = nrounds)
#' )
#' xgb.importance(
#' model = mbst, trees = seq(from = 2, by = nclass, length.out = nrounds)
#' )
#' xgb.importance(
#' model = mbst, trees = seq(from = 3, by = nclass, length.out = nrounds)
#' )
#'
#' # multiclass classification using "gblinear":
#' mbst <- xgb.train(
Expand All @@ -122,15 +118,21 @@
#' xgb.importance(model = mbst)
#'
#' @export
xgb.importance <- function(model = NULL, feature_names = getinfo(model, "feature_name"), trees = NULL,
data = NULL, label = NULL, target = NULL) {

if (!(is.null(data) && is.null(label) && is.null(target)))
warning("xgb.importance: parameters 'data', 'label' and 'target' are deprecated")
xgb.importance <- function(model = NULL, feature_names = getinfo(model, "feature_name"), trees = NULL) {

if (!(is.null(feature_names) || is.character(feature_names)))
stop("feature_names: Has to be a character vector")

if (!is.null(trees)) {
if (!is.vector(trees)) {
stop("'trees' must be a vector of tree indices.")
}
trees <- trees - 1L
if (anyNA(trees)) {
stop("Passed invalid tree indices.")
}
}

handle <- xgb.get.handle(model)
if (xgb.booster_type(model) == "gblinear") {
args <- list(importance_type = "weight", feature_names = feature_names)
Expand Down
15 changes: 10 additions & 5 deletions R-package/R/xgb.model.dt.tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
#' be set through [setinfo()]), they will be used in the output from this function.
#' @param text Character vector previously generated by the function [xgb.dump()]
#' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`.
#' @param trees An integer vector of tree indices that should be used. The default
#' @param trees An integer vector of (base-1) tree indices that should be used. The default
#' (`NULL`) uses all trees. Useful, e.g., in multiclass classification to get only
#' the trees of one class. *Important*: the tree index in XGBoost models
#' is zero-based (e.g., use `trees = 0:4` for the first five trees).
#' the trees of one class.
#' @param use_int_id A logical flag indicating whether nodes in columns "Yes", "No", and
#' "Missing" should be represented as integers (when `TRUE`) or as "Tree-Node"
#' character strings (when `FALSE`, default).
Expand Down Expand Up @@ -75,8 +74,14 @@ xgb.model.dt.tree <- function(model = NULL, text = NULL,
" (or NULL if 'model' was provided).")
}

if (!(is.null(trees) || is.numeric(trees))) {
stop("trees: must be a vector of integers.")
if (!is.null(trees)) {
if (!is.vector(trees) || (!is.numeric(trees) && !is.integer(trees))) {
stop("trees: must be a vector of integers.")
}
trees <- trees - 1L
if (anyNA(trees) || min(trees) < 0) {
stop("Passed invalid tree indices.")
}
}

feature_names <- NULL
Expand Down
2 changes: 1 addition & 1 deletion R-package/R/xgb.plot.shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
#' num_class = nclass
#' )
#' )
#' trees0 <- seq(from = 0, by = nclass, length.out = nrounds)
#' trees0 <- seq(from = 1, by = nclass, length.out = nrounds)
#' col <- rgb(0, 0, 1, 0.5)
#'
#' xgb.plot.shap(
Expand Down
22 changes: 6 additions & 16 deletions R-package/man/xgb.importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions R-package/man/xgb.model.dt.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R-package/man/xgb.plot.shap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ test_that("xgb.importance works with and without feature names", {
imp.Tree <- xgb.importance(model = mbst.Tree)
expect_equal(dim(imp.Tree), c(4, 4))

trees <- seq(from = 0, by = 2, length.out = 2)
trees <- seq(from = 1, by = 2, length.out = 2)
importance <- xgb.importance(feature_names = feature.names, model = bst.Tree, trees = trees)

importance_from_dump <- function() {
Expand Down
Loading