@@ -343,6 +343,10 @@ def _import_dataspec(self, src_dataspec: data_spec_pb2.DataSpecification):
343
343
if dst_col_idx == self ._header .ranking_group_col_idx :
344
344
continue
345
345
346
+ if isinstance (self ._objective , py_tree .objective .AbstractUpliftObjective ):
347
+ if dst_col_idx == self ._header .uplift_treatment_col_idx :
348
+ continue
349
+
346
350
if not created :
347
351
raise ValueError (
348
352
"import_dataspec was called after some of the model was build. "
@@ -511,7 +515,11 @@ def _initialize_header_column_idx(self):
511
515
label_column .name = self ._objective .label
512
516
self ._dataspec_column_index [label_column .name ] = self ._header .label_col_idx
513
517
514
- if isinstance (self ._objective , py_tree .objective .ClassificationObjective ):
518
+ if isinstance (
519
+ self ._objective , py_tree .objective .ClassificationObjective
520
+ ) or isinstance (
521
+ self ._objective , py_tree .objective .CategoricalUpliftObjective
522
+ ):
515
523
label_column .type = ColumnType .CATEGORICAL
516
524
517
525
# One value is reserved for the non-used OOV item.
@@ -537,6 +545,7 @@ def _initialize_header_column_idx(self):
537
545
(
538
546
py_tree .objective .RegressionObjective ,
539
547
py_tree .objective .RankingObjective ,
548
+ py_tree .objective .NumericalUpliftObjective ,
540
549
),
541
550
):
542
551
label_column .type = ColumnType .NUMERICAL
@@ -556,6 +565,18 @@ def _initialize_header_column_idx(self):
556
565
self ._header .ranking_group_col_idx
557
566
)
558
567
568
+ if isinstance (self ._objective , py_tree .objective .AbstractUpliftObjective ):
569
+ assert len (self ._dataspec .columns ) == 1
570
+
571
+ # Create the "treatment" column for Uplifting.
572
+ self ._header .uplift_treatment_col_idx = 1
573
+ treatment_column = self ._dataspec .columns .add ()
574
+ treatment_column .type = ColumnType .CATEGORICAL
575
+ treatment_column .name = self ._objective .treatment
576
+ self ._dataspec_column_index [treatment_column .name ] = (
577
+ self ._header .uplift_treatment_col_idx
578
+ )
579
+
559
580
560
581
@six .add_metaclass (abc .ABCMeta )
561
582
class AbstractDecisionForestBuilder (AbstractBuilder ):
@@ -850,8 +871,11 @@ def check_leaf(self, node: py_tree.node.LeafNode):
850
871
"A regression objective requires leaf nodes with regressive values."
851
872
)
852
873
853
- elif isinstance (self .objective , py_tree .objective .RankingObjective ):
854
- raise ValueError ("Ranking objective not supported by this model" )
874
+ elif isinstance (self .objective , py_tree .objective .AbstractUpliftObjective ):
875
+ if not isinstance (node .value , py_tree .value .UpliftValue ):
876
+ raise ValueError (
877
+ "An uplift objective requires leaf nodes with uplift values."
878
+ )
855
879
856
880
else :
857
881
raise NotImplementedError ()
@@ -920,6 +944,11 @@ def __init__(
920
944
loss = gradient_boosted_trees_pb2 .Loss .LAMBDA_MART_NDCG5
921
945
bias = [bias ]
922
946
947
+ elif isinstance (objective , py_tree .objective .AbstractUpliftObjective ):
948
+ raise ValueError (
949
+ "Uplift objective not supported by Gradient Boosted Tree models."
950
+ )
951
+
923
952
else :
924
953
raise NotImplementedError ()
925
954
@@ -972,7 +1001,12 @@ def specialized_header_filename(self) -> str:
972
1001
return self ._file_prefix + inspector_lib .BASE_FILENAME_GBT_HEADER
973
1002
974
1003
def check_leaf (self , node : py_tree .node .LeafNode ):
975
- if not isinstance (node .value , py_tree .value .RegressionValue ):
1004
+ if isinstance (self .objective , py_tree .objective .AbstractUpliftObjective ):
1005
+ raise ValueError (
1006
+ "Uplift objective not supported by Gradient Boosted Tree models."
1007
+ )
1008
+
1009
+ elif not isinstance (node .value , py_tree .value .RegressionValue ):
976
1010
raise ValueError (
977
1011
"A GBT model should only have leaf with regressive "
978
1012
f"value. Got { node .value } instead."
0 commit comments