Skip to content

Commit

Permalink
Update tree code (#2383)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Mar 6, 2024
1 parent a3f7724 commit 0c30363
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 39 deletions.
2 changes: 1 addition & 1 deletion nvflare/app_opt/xgboost/tree_based/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
from nvflare.app_opt.xgboost.tree_based.shareable_generator import update_model
from nvflare.app_opt.xgboost.tree_based.utils import update_model
from nvflare.fuel.utils.import_utils import optional_import
from nvflare.security.logging import secure_format_exception

Expand Down
39 changes: 1 addition & 38 deletions nvflare/app_opt/xgboost/tree_based/shareable_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,7 @@
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, model_learnable_to_dxo
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants


def _get_xgboost_model_attr(xgb_model):
num_parallel_tree = int(
xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_parallel_tree"]
)
num_trees = int(xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])
return num_parallel_tree, num_trees


def update_model(prev_model, model_update):
if not prev_model:
return model_update
else:
# Append all trees
# get the parameters
pre_num_parallel_tree, pre_num_trees = _get_xgboost_model_attr(prev_model)
cur_num_parallel_tree, add_num_trees = _get_xgboost_model_attr(model_update)

# check num_parallel_tree, should be consistent
if cur_num_parallel_tree != pre_num_parallel_tree:
raise ValueError(
f"add_num_parallel_tree should not change, previous {pre_num_parallel_tree}, current {add_num_parallel_tree}"
)
prev_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] = str(
pre_num_trees + cur_num_parallel_tree
)
# append the new trees
append_info = model_update["learner"]["gradient_booster"]["model"]["trees"]
for tree_ct in range(cur_num_parallel_tree):
append_info[tree_ct]["id"] = pre_num_trees + tree_ct
prev_model["learner"]["gradient_booster"]["model"]["trees"].append(append_info[tree_ct])
prev_model["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
# append iteration_indptr
prev_model["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
pre_num_trees + cur_num_parallel_tree
)
return prev_model
from nvflare.app_opt.xgboost.tree_based.utils import update_model


class XGBModelShareableGenerator(ShareableGenerator):
Expand Down
51 changes: 51 additions & 0 deletions nvflare/app_opt/xgboost/tree_based/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def _get_xgboost_model_attr(xgb_model):
num_parallel_tree = int(
xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_parallel_tree"]
)
num_trees = int(xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])
return num_parallel_tree, num_trees


def update_model(prev_model, model_update):
if not prev_model:
return model_update
else:
# Append all trees
# get the parameters
pre_num_parallel_tree, pre_num_trees = _get_xgboost_model_attr(prev_model)
cur_num_parallel_tree, add_num_trees = _get_xgboost_model_attr(model_update)

# check num_parallel_tree, should be consistent
if cur_num_parallel_tree != pre_num_parallel_tree:
raise ValueError(
f"add_num_parallel_tree should not change, previous {pre_num_parallel_tree}, current {cur_num_parallel_tree}"
)
prev_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] = str(
pre_num_trees + cur_num_parallel_tree
)
# append the new trees
append_info = model_update["learner"]["gradient_booster"]["model"]["trees"]
for tree_ct in range(cur_num_parallel_tree):
append_info[tree_ct]["id"] = pre_num_trees + tree_ct
prev_model["learner"]["gradient_booster"]["model"]["trees"].append(append_info[tree_ct])
prev_model["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
# append iteration_indptr
prev_model["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
pre_num_trees + cur_num_parallel_tree
)
return prev_model

0 comments on commit 0c30363

Please sign in to comment.