Skip to content

Commit

Permalink
explicit error for unsupported categorical features
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes committed Dec 7, 2024
1 parent 4b2001e commit 1cb28ec
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 8 deletions.
4 changes: 4 additions & 0 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,10 @@ xgb.best_iteration <- function(bst) {
return(out)
}

xgb.has_categ_features <- function(bst) {
return("c" %in% xgb.feature_types(bst))
}

#' Extract coefficients from linear booster
#'
#' @description
Expand Down
8 changes: 8 additions & 0 deletions R-package/R/xgb.model.dt.tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
#'
#' Parse a boosted tree model text dump into a `data.table` structure.
#'
#' Note that this function does not work with models that were fitted to
#' categorical data.
#' @param model Object of class `xgb.Booster`. If it contains feature names (they can
#' be set through [setinfo()]), they will be used in the output from this function.
#'
#' If the model contains categorical features, an error will be thrown.
#' @param text Character vector previously generated by the function [xgb.dump()]
#' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`.
#' @param trees An integer vector of tree indices that should be used. The default
Expand Down Expand Up @@ -68,6 +72,10 @@ xgb.model.dt.tree <- function(model = NULL, text = NULL,
trees = NULL, use_int_id = FALSE, ...) {
check.deprecation(...)

if (xgb.has_categ_features(model)) {
stop("Cannot produce tables for models having categorical features.")
}

if (!inherits(model, "xgb.Booster") && !is.character(text)) {
stop("Either 'model' must be an object of class xgb.Booster\n",
" or 'text' must be a character vector with the result of xgb.dump\n",
Expand Down
8 changes: 8 additions & 0 deletions R-package/R/xgb.plot.multi.trees.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#'
#' Visualization of the ensemble of trees as a single collective unit.
#'
#' Note that this function does not work with models that were fitted to
#' categorical data.
#' @details
#' This function tries to capture the complexity of a gradient boosted tree model
#' in a cohesive way by compressing an ensemble of trees into a single tree-graph representation.
Expand Down Expand Up @@ -66,6 +68,12 @@ xgb.plot.multi.trees <- function(model, features_keep = 5, plot_width = NULL, pl
stop("DiagrammeR is required for xgb.plot.multi.trees")
}
check.deprecation(...)
if (xgb.has_categ_features(model)) {
stop(
"Cannot use 'xgb.plot.multi.trees' for models with categorical features.",
" Try 'xgb.plot.tree' instead."
)
}
tree.matrix <- xgb.model.dt.tree(model = model)

# first number of the path represents the tree, then the following numbers are related to the path to follow
Expand Down
16 changes: 8 additions & 8 deletions R-package/man/xgb.model.dt.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions R-package/man/xgb.plot.multi.trees.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1cb28ec

Please sign in to comment.