|
19 | 19 | """Cost model based on xgboost""" |
20 | 20 | import multiprocessing |
21 | 21 | import logging |
| 22 | +from typing import Dict |
22 | 23 | from collections import defaultdict |
23 | 24 |
|
24 | 25 | import numpy as np |
|
28 | 29 | from ..feature import get_per_store_features_from_measure_pairs, get_per_store_features_from_states |
29 | 30 | from ..measure_record import RecordReader |
30 | 31 |
|
# `xgboost.callback.TrainingCallback` is the new-style callback base class.
# On older xgboost versions (pre-TrainingCallback API) it does not exist, so
# provide an empty stand-in; `XGBoostCallback.__call__` below handles the
# old-style (xgboost < 1.3) callback protocol in that case.
try:
    from xgboost.callback import TrainingCallback  # type: ignore
except ImportError:

    class TrainingCallback:  # type: ignore
        pass
| 39 | + |
31 | 40 | xgb = None |
32 | 41 |
|
33 | 42 | logger = logging.getLogger("auto_scheduler") |
@@ -198,7 +207,7 @@ def update(self, inputs, results): |
198 | 207 | num_boost_round=10000, |
199 | 208 | obj=pack_sum_square_error, |
200 | 209 | callbacks=[ |
201 | | - custom_callback( |
| 210 | + CustomCallback( |
202 | 211 | stopping_rounds=50, |
203 | 212 | metric="tr-p-rmse", |
204 | 213 | fevals=[ |
@@ -539,125 +548,144 @@ def feval(preds, labels): |
539 | 548 | return feval |
540 | 549 |
|
541 | 550 |
|
542 | | -def custom_callback( |
543 | | - stopping_rounds, |
544 | | - metric, |
545 | | - fevals, |
546 | | - evals=(), |
547 | | - log_file=None, |
548 | | - maximize=False, |
549 | | - verbose_eval=True, |
550 | | - skip_every=2, |
551 | | -): |
552 | | - """Callback function for xgboost to support multiple custom evaluation functions""" |
553 | | - # pylint: disable=import-outside-toplevel |
554 | | - from xgboost.core import EarlyStopException |
555 | | - from xgboost.callback import _fmt_metric |
556 | | - |
557 | | - try: |
558 | | - from xgboost.training import aggcv |
559 | | - except ImportError: |
560 | | - from xgboost.callback import _aggcv as aggcv |
561 | | - |
562 | | - state = {} |
563 | | - metric_shortname = metric.split("-")[1] |
564 | | - |
565 | | - def init(env): |
566 | | - """internal function""" |
567 | | - bst = env.model |
568 | | - |
569 | | - state["maximize_score"] = maximize |
570 | | - state["best_iteration"] = 0 |
571 | | - if maximize: |
572 | | - state["best_score"] = float("-inf") |
573 | | - else: |
574 | | - state["best_score"] = float("inf") |
class XGBoostCallback(TrainingCallback):
    """Base class for XGBoost callbacks.

    Bridges the two xgboost callback protocols: new-style xgboost invokes
    ``after_iteration`` directly, while old-style (< 1.3) xgboost calls the
    callback with a ``CallbackEnv`` object, which ``__call__`` unpacks.
    """

    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
        """Hook run after each boosting iteration; subclasses must override."""
        raise NotImplementedError

    def __call__(self, env: "xgb.core.CallbackEnv"):
        # Compatibility with xgboost < 1.3: translate the legacy CallbackEnv
        # payload into an after_iteration call.
        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
| 560 | + |
| 561 | + |
class CustomCallback(XGBoostCallback):
    """
    Callback function for xgboost.
    Support custom evaluation function and early-stopping.
    """

    def __init__(
        self,
        stopping_rounds,
        metric,
        fevals,
        evals=(),
        log_file=None,
        maximize=False,
        verbose_eval=True,
        skip_every=2,
    ):
        """Init function

        Parameters
        ----------
        stopping_rounds : int
            Stop training when `metric` has not improved for this many
            iterations past the best one.
        metric : str
            Full name of the early-stopping metric, e.g. "tr-p-rmse".
        fevals : list
            Custom evaluation functions, each passed to `Booster.eval_set`.
        evals : tuple
            Evaluation sets forwarded to `Booster.eval_set`.
        log_file : Optional[str]
            If set, evaluation lines are appended to this file.
        maximize : bool
            Whether a larger metric value is better.
        verbose_eval : Union[bool, int]
            When an int, log evaluation results every `verbose_eval` epochs.
        skip_every : int
            Skip evaluation on epochs where `epoch % skip_every == 1`.
        """
        self.stopping_rounds = stopping_rounds
        self.metric = metric
        # e.g. "p-rmse" out of "tr-p-rmse"; used to sort the stopping metric
        # first in the printed results.
        self.metric_shortname = metric.split("-")[1]
        self.fevals = fevals
        self.evals = evals
        self.log_file = log_file
        self.maximize = maximize
        self.verbose_eval = verbose_eval
        self.skip_every = skip_every
        # Early-stopping state (best score/iteration/message); filled lazily
        # on the first call to after_iteration.
        self.state = {}

    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
        """Run after each iteration.  Return True when training should stop."""
        # pylint:disable = import-outside-toplevel
        try:
            from xgboost.callback import _fmt_metric  # type: ignore
        except ImportError:
            # Compatibility with xgboost >= 1.6, where _fmt_metric was removed:
            # reimplement the same "name:value[+stdv]" formatting locally.
            def _fmt_metric(value, show_stdv=True):
                """format metric string"""
                if len(value) == 2:
                    return f"{value[0]}:{value[1]:.5f}"
                if len(value) == 3:
                    if show_stdv:
                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
                    return f"{value[0]}:{value[1]:.5f}"
                raise ValueError("wrong metric value", value)

        ##### init state #####
        if not self.state:
            self.state["maximize_score"] = self.maximize
            self.state["best_iteration"] = 0
            if self.maximize:
                self.state["best_score"] = float("-inf")
            else:
                self.state["best_score"] = float("inf")

            # If the booster already carries best_* attributes (e.g. training
            # resumed from a checkpoint), restore the state from them;
            # otherwise seed the attributes from the fresh state.
            assert model is not None
            if model.attr("best_score") is not None:
                self.state["best_score"] = float(model.attr("best_score"))
                self.state["best_iteration"] = int(model.attr("best_iteration"))
                self.state["best_msg"] = model.attr("best_msg")
            else:
                model.set_attr(best_iteration=str(self.state["best_iteration"]))
                model.set_attr(best_score=str(self.state["best_score"]))
        res_dict = {}

        # Evaluation is throttled: epochs congruent to 1 mod skip_every are
        # skipped entirely (no eval, no early-stopping bookkeeping).
        if epoch % self.skip_every == 1:
            return False

        ##### evaluation #####
        # eval_set returns a string like "[epoch]\tname:value\tname:value...";
        # parse it (dropping the leading "[epoch]" token) into {name: [value]}.
        for feval in self.fevals:
            bst_eval = model.eval_set(self.evals, epoch, feval)
            res = [x.split(":") for x in bst_eval.split()]
            for kv in res[1:]:
                res_dict[kv[0]] = [float(kv[1])]

        eval_res = []
        keys = list(res_dict.keys())
        # Prefixing "a" sorts entries containing the stopping metric first.
        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + x)
        for key in keys:
            v = res_dict[key]
            eval_res.append([key] + v)

        ##### print eval result #####
        # Only an integer verbose_eval enables periodic logging (a bare True
        # is intentionally excluded by the isinstance check).
        if (
            not isinstance(self.verbose_eval, bool)
            and self.verbose_eval
            and epoch % self.verbose_eval == 0
        ):
            infos = ["XGB iter: %3d" % epoch]
            for item in eval_res:
                if "null" in item[0]:
                    continue
                infos.append("%s: %.6f" % (item[0], item[1]))

            logger.debug("\t".join(infos))
            if self.log_file:
                with open(self.log_file, "a") as fout:
                    fout.write("\t".join(infos) + "\n")

        ##### choose score and do early stopping #####
        score = None
        for item in eval_res:
            if item[0] == self.metric:
                score = item[1]
                break
        # The stopping metric must be present among the evaluation results.
        assert score is not None

        best_score = self.state["best_score"]
        best_iteration = self.state["best_iteration"]
        maximize_score = self.state["maximize_score"]

        if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
            msg = "[%d] %s" % (epoch, "\t".join([_fmt_metric(x) for x in eval_res]))
            self.state["best_msg"] = msg
            self.state["best_score"] = score
            self.state["best_iteration"] = epoch
            # save the property to attributes, so they will occur in checkpoint.
            if model is not None:
                model.set_attr(
                    best_score=str(self.state["best_score"]),
                    best_iteration=str(self.state["best_iteration"]),
                    best_msg=self.state["best_msg"],
                )
        elif epoch - best_iteration >= self.stopping_rounds:
            # No improvement for stopping_rounds iterations: signal xgboost to
            # stop (new-style API stops when after_iteration returns True).
            best_msg = self.state["best_msg"]
            if self.verbose_eval:
                logger.debug("XGB stopped. Best iteration: %s ", best_msg)
            return True

        return False
0 commit comments