diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 6ffaa299b500..011c73ab5892 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -1055,6 +1055,10 @@ xgb.best_iteration <- function(bst) { return(out) } +xgb.has_categ_features <- function(bst) { + return("c" %in% xgb.feature_types(bst)) +} + #' Extract coefficients from linear booster #' #' @description diff --git a/R-package/R/xgb.model.dt.tree.R b/R-package/R/xgb.model.dt.tree.R index 789737d4c701..cb7a51a1e555 100644 --- a/R-package/R/xgb.model.dt.tree.R +++ b/R-package/R/xgb.model.dt.tree.R @@ -2,10 +2,12 @@ #' #' Parse a boosted tree model text dump into a `data.table` structure. #' +#' Note that this function does not work with models that were fitted to +#' categorical data, and is only applicable to tree-based boosters (not `gblinear`). #' @param model Object of class `xgb.Booster`. If it contains feature names (they can #' be set through [setinfo()]), they will be used in the output from this function. -#' @param text Character vector previously generated by the function [xgb.dump()] -#' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`. +#' +#' If the model contains categorical features, an error will be thrown. #' @param trees An integer vector of (base-1) tree indices that should be used. The default #' (`NULL`) uses all trees. Useful, e.g., in multiclass classification to get only #' the trees of one class. @@ -64,14 +66,15 @@ #' ] #' #' @export -xgb.model.dt.tree <- function(model = NULL, text = NULL, - trees = NULL, use_int_id = FALSE, ...) { +xgb.model.dt.tree <- function(model, trees = NULL, use_int_id = FALSE, ...) { check.deprecation(deprecated_dttree_params, match.call(), ...) 
- if (!inherits(model, "xgb.Booster") && !is.character(text)) { - stop("Either 'model' must be an object of class xgb.Booster\n", - " or 'text' must be a character vector with the result of xgb.dump\n", - " (or NULL if 'model' was provided).") + if (!inherits(model, "xgb.Booster")) { + stop("'model' must be an object of class xgb.Booster") + } + + if (xgb.has_categ_features(model)) { + stop("Cannot produce tables for models having categorical features.") } if (!is.null(trees)) { @@ -89,11 +92,7 @@ xgb.model.dt.tree <- function(model = NULL, text = NULL, feature_names <- xgb.feature_names(model) } - from_text <- TRUE - if (is.null(text)) { - text <- xgb.dump(model = model, with_stats = TRUE) - from_text <- FALSE - } + text <- xgb.dump(model = model, with_stats = TRUE) if (length(text) < 2 || !any(grepl('leaf=(-?\\d+)', text))) { stop("Non-tree model detected! This function can only be used with tree models.") @@ -130,15 +129,7 @@ xgb.model.dt.tree <- function(model = NULL, text = NULL, branch_rx <- branch_rx_w_names text_has_feature_names <- TRUE } else { - # Note: when passing a text dump, it might or might not have feature names, - # but that aspect is unknown from just the text attributes branch_rx <- branch_rx_nonames - if (from_text) { - if (sum(grepl(branch_rx_w_names, text)) > sum(grepl(branch_rx_nonames, text))) { - branch_rx <- branch_rx_w_names - text_has_feature_names <- TRUE - } - } } branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Gain", "Cover") td[ diff --git a/R-package/R/xgb.plot.multi.trees.R b/R-package/R/xgb.plot.multi.trees.R index 1c57dd84babd..e67a4ec7eab9 100644 --- a/R-package/R/xgb.plot.multi.trees.R +++ b/R-package/R/xgb.plot.multi.trees.R @@ -2,6 +2,8 @@ #' #' Visualization of the ensemble of trees as a single collective unit. #' +#' Note that this function does not work with models that were fitted to +#' categorical data. 
#' @details #' This function tries to capture the complexity of a gradient boosted tree model #' in a cohesive way by compressing an ensemble of trees into a single tree-graph representation. @@ -64,10 +66,16 @@ #' @export xgb.plot.multi.trees <- function(model, features_keep = 5, plot_width = NULL, plot_height = NULL, render = TRUE, ...) { + check.deprecation(deprecated_multitrees_params, match.call(), ...) if (!requireNamespace("DiagrammeR", quietly = TRUE)) { stop("DiagrammeR is required for xgb.plot.multi.trees") } - check.deprecation(deprecated_multitrees_params, match.call(), ...) + if (xgb.has_categ_features(model)) { + stop( + "Cannot use 'xgb.plot.multi.trees' for models with categorical features.", + " Try 'xgb.plot.tree' instead." + ) + } tree.matrix <- xgb.model.dt.tree(model = model) # first number of the path represents the tree, then the following numbers are related to the path to follow diff --git a/R-package/man/xgb.model.dt.tree.Rd b/R-package/man/xgb.model.dt.tree.Rd index 424552e490cd..8495cb2ac77a 100644 --- a/R-package/man/xgb.model.dt.tree.Rd +++ b/R-package/man/xgb.model.dt.tree.Rd @@ -4,20 +4,13 @@ \alias{xgb.model.dt.tree} \title{Parse model text dump} \usage{ -xgb.model.dt.tree( - model = NULL, - text = NULL, - trees = NULL, - use_int_id = FALSE, - ... -) +xgb.model.dt.tree(model, trees = NULL, use_int_id = FALSE, ...) } \arguments{ \item{model}{Object of class \code{xgb.Booster}. If it contains feature names (they can -be set through \code{\link[=setinfo]{setinfo()}}), they will be used in the output from this function.} +be set through \code{\link[=setinfo]{setinfo()}}), they will be used in the output from this function. -\item{text}{Character vector previously generated by the function \code{\link[=xgb.dump]{xgb.dump()}} -(called with parameter \code{with_stats = TRUE}). 
\code{text} takes precedence over \code{model}.} +If the model contains categorical features, an error will be thrown.} \item{trees}{An integer vector of (base-1) tree indices that should be used. The default (\code{NULL}) uses all trees. Useful, e.g., in multiclass classification to get only @@ -59,6 +52,10 @@ the corresponding trees in the "Node" column. \description{ Parse a boosted tree model text dump into a \code{data.table} structure. } +\details{ +Note that this function does not work with models that were fitted to +categorical data, and is only applicable to tree-based boosters (not \code{gblinear}). +} \examples{ # Basic use: diff --git a/R-package/man/xgb.plot.multi.trees.Rd b/R-package/man/xgb.plot.multi.trees.Rd index 989096e60e42..a6be36b97d53 100644 --- a/R-package/man/xgb.plot.multi.trees.Rd +++ b/R-package/man/xgb.plot.multi.trees.Rd @@ -43,6 +43,9 @@ line. Visualization of the ensemble of trees as a single collective unit. } \details{ +Note that this function does not work with models that were fitted to +categorical data. + This function tries to capture the complexity of a gradient boosted tree model in a cohesive way by compressing an ensemble of trees into a single tree-graph representation. The goal is to improve the interpretability of a model generally seen as black box. diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 26a555ea55e3..1e1054e73617 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -356,9 +356,8 @@ test_that("xgb.importance works with and without feature names", { importance <- xgb.importance(feature_names = feature.names, model = bst.Tree, trees = trees) importance_from_dump <- function() { - model_text_dump <- xgb.dump(model = bst.Tree, with_stats = TRUE) imp <- xgb.model.dt.tree( - text = model_text_dump, + model = bst.Tree, trees = trees )[ Feature != "Leaf", .(