Skip to content

Commit

Permalink
[R] Throw explicit error when using categorical features in functions…
Browse files Browse the repository at this point in the history
… which don't support them (#11077)

Co-authored-by: Jiaming Yuan <[email protected]>
  • Loading branch information
david-cortes and trivialfis authored Dec 15, 2024
1 parent 5502558 commit dea5753
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 34 deletions.
4 changes: 4 additions & 0 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,10 @@ xgb.best_iteration <- function(bst) {
return(out)
}

# Internal helper: reports whether the feature types of booster `bst`
# (as returned by `xgb.feature_types()`) include the type code "c".
# Judging by this helper's name, "c" is the code for a categorical feature.
# Returns a single logical value.
xgb.has_categ_features <- function(bst) {
  feature_types <- xgb.feature_types(bst)
  "c" %in% feature_types
}

#' Extract coefficients from linear booster
#'
#' @description
Expand Down
33 changes: 12 additions & 21 deletions R-package/R/xgb.model.dt.tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#'
#' Parse a boosted tree model text dump into a `data.table` structure.
#'
#' Note that this function does not work with models that were fitted to
#' categorical data, and is only applicable to tree-based boosters (not `gblinear`).
#' @param model Object of class `xgb.Booster`. If it contains feature names (they can
#' be set through [setinfo()]), they will be used in the output from this function.
#' @param text Character vector previously generated by the function [xgb.dump()]
#' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`.
#'
#' If the model contains categorical features, an error will be thrown.
#' @param trees An integer vector of (base-1) tree indices that should be used. The default
#' (`NULL`) uses all trees. Useful, e.g., in multiclass classification to get only
#' the trees of one class.
Expand Down Expand Up @@ -64,14 +66,15 @@
#' ]
#'
#' @export
xgb.model.dt.tree <- function(model = NULL, text = NULL,
trees = NULL, use_int_id = FALSE, ...) {
xgb.model.dt.tree <- function(model, trees = NULL, use_int_id = FALSE, ...) {
check.deprecation(deprecated_dttree_params, match.call(), ...)

if (!inherits(model, "xgb.Booster") && !is.character(text)) {
stop("Either 'model' must be an object of class xgb.Booster\n",
" or 'text' must be a character vector with the result of xgb.dump\n",
" (or NULL if 'model' was provided).")
if (!inherits(model, "xgb.Booster")) {
stop("Either 'model' must be an object of class xgb.Booster")
}

if (xgb.has_categ_features(model)) {
stop("Cannot produce tables for models having categorical features.")
}

if (!is.null(trees)) {
Expand All @@ -89,11 +92,7 @@ xgb.model.dt.tree <- function(model = NULL, text = NULL,
feature_names <- xgb.feature_names(model)
}

from_text <- TRUE
if (is.null(text)) {
text <- xgb.dump(model = model, with_stats = TRUE)
from_text <- FALSE
}
text <- xgb.dump(model = model, with_stats = TRUE)

if (length(text) < 2 || !any(grepl('leaf=(-?\\d+)', text))) {
stop("Non-tree model detected! This function can only be used with tree models.")
Expand Down Expand Up @@ -130,15 +129,7 @@ xgb.model.dt.tree <- function(model = NULL, text = NULL,
branch_rx <- branch_rx_w_names
text_has_feature_names <- TRUE
} else {
# Note: when passing a text dump, it might or might not have feature names,
# but that aspect is unknown from just the text attributes
branch_rx <- branch_rx_nonames
if (from_text) {
if (sum(grepl(branch_rx_w_names, text)) > sum(grepl(branch_rx_nonames, text))) {
branch_rx <- branch_rx_w_names
text_has_feature_names <- TRUE
}
}
}
branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Gain", "Cover")
td[
Expand Down
10 changes: 9 additions & 1 deletion R-package/R/xgb.plot.multi.trees.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#'
#' Visualization of the ensemble of trees as a single collective unit.
#'
#' Note that this function does not work with models that were fitted to
#' categorical data.
#' @details
#' This function tries to capture the complexity of a gradient boosted tree model
#' in a cohesive way by compressing an ensemble of trees into a single tree-graph representation.
Expand Down Expand Up @@ -64,10 +66,16 @@
#' @export
xgb.plot.multi.trees <- function(model, features_keep = 5, plot_width = NULL, plot_height = NULL,
render = TRUE, ...) {
check.deprecation(deprecated_multitrees_params, match.call(), ...)
if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
stop("DiagrammeR is required for xgb.plot.multi.trees")
}
check.deprecation(deprecated_multitrees_params, match.call(), ...)
if (xgb.has_categ_features(model)) {
stop(
"Cannot use 'xgb.plot.multi.trees' for models with categorical features.",
" Try 'xgb.plot.tree' instead."
)
}
tree.matrix <- xgb.model.dt.tree(model = model)

# first number of the path represents the tree, then the following numbers are related to the path to follow
Expand Down
17 changes: 7 additions & 10 deletions R-package/man/xgb.model.dt.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions R-package/man/xgb.plot.multi.trees.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions R-package/tests/testthat/test_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,8 @@ test_that("xgb.importance works with and without feature names", {
importance <- xgb.importance(feature_names = feature.names, model = bst.Tree, trees = trees)

importance_from_dump <- function() {
model_text_dump <- xgb.dump(model = bst.Tree, with_stats = TRUE)
imp <- xgb.model.dt.tree(
text = model_text_dump,
model = bst.Tree,
trees = trees
)[
Feature != "Leaf", .(
Expand Down

0 comments on commit dea5753

Please sign in to comment.