Skip to content

Commit 244576e

Browse files
rstzcopybara-github
authored andcommitted
Add support for uplifting in the inspector.
This allows exporting uplift models from YDF to TF PiperOrigin-RevId: 579181978
1 parent c6fee4a commit 244576e

File tree

6 files changed

+150
-6
lines changed

6 files changed

+150
-6
lines changed

tensorflow_decision_forests/component/builder/builder.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,10 @@ def _import_dataspec(self, src_dataspec: data_spec_pb2.DataSpecification):
343343
if dst_col_idx == self._header.ranking_group_col_idx:
344344
continue
345345

346+
if isinstance(self._objective, py_tree.objective.AbstractUpliftObjective):
347+
if dst_col_idx == self._header.uplift_treatment_col_idx:
348+
continue
349+
346350
if not created:
347351
raise ValueError(
348352
"import_dataspec was called after some of the model was build. "
@@ -511,7 +515,11 @@ def _initialize_header_column_idx(self):
511515
label_column.name = self._objective.label
512516
self._dataspec_column_index[label_column.name] = self._header.label_col_idx
513517

514-
if isinstance(self._objective, py_tree.objective.ClassificationObjective):
518+
if isinstance(
519+
self._objective, py_tree.objective.ClassificationObjective
520+
) or isinstance(
521+
self._objective, py_tree.objective.CategoricalUpliftObjective
522+
):
515523
label_column.type = ColumnType.CATEGORICAL
516524

517525
# One value is reserved for the non-used OOV item.
@@ -537,6 +545,7 @@ def _initialize_header_column_idx(self):
537545
(
538546
py_tree.objective.RegressionObjective,
539547
py_tree.objective.RankingObjective,
548+
py_tree.objective.NumericalUpliftObjective,
540549
),
541550
):
542551
label_column.type = ColumnType.NUMERICAL
@@ -556,6 +565,18 @@ def _initialize_header_column_idx(self):
556565
self._header.ranking_group_col_idx
557566
)
558567

568+
if isinstance(self._objective, py_tree.objective.AbstractUpliftObjective):
569+
assert len(self._dataspec.columns) == 1
570+
571+
# Create the "treatment" column for Uplifting.
572+
self._header.uplift_treatment_col_idx = 1
573+
treatment_column = self._dataspec.columns.add()
574+
treatment_column.type = ColumnType.CATEGORICAL
575+
treatment_column.name = self._objective.treatment
576+
self._dataspec_column_index[treatment_column.name] = (
577+
self._header.uplift_treatment_col_idx
578+
)
579+
559580

560581
@six.add_metaclass(abc.ABCMeta)
561582
class AbstractDecisionForestBuilder(AbstractBuilder):
@@ -850,8 +871,11 @@ def check_leaf(self, node: py_tree.node.LeafNode):
850871
"A regression objective requires leaf nodes with regressive values."
851872
)
852873

853-
elif isinstance(self.objective, py_tree.objective.RankingObjective):
854-
raise ValueError("Ranking objective not supported by this model")
874+
elif isinstance(self.objective, py_tree.objective.AbstractUpliftObjective):
875+
if not isinstance(node.value, py_tree.value.UpliftValue):
876+
raise ValueError(
877+
"An uplift objective requires leaf nodes with uplift values."
878+
)
855879

856880
else:
857881
raise NotImplementedError()
@@ -920,6 +944,11 @@ def __init__(
920944
loss = gradient_boosted_trees_pb2.Loss.LAMBDA_MART_NDCG5
921945
bias = [bias]
922946

947+
elif isinstance(objective, py_tree.objective.AbstractUpliftObjective):
948+
raise ValueError(
949+
"Uplift objective not supported by Gradient Boosted Tree models."
950+
)
951+
923952
else:
924953
raise NotImplementedError()
925954

@@ -972,7 +1001,12 @@ def specialized_header_filename(self) -> str:
9721001
return self._file_prefix + inspector_lib.BASE_FILENAME_GBT_HEADER
9731002

9741003
def check_leaf(self, node: py_tree.node.LeafNode):
975-
if not isinstance(node.value, py_tree.value.RegressionValue):
1004+
if isinstance(self.objective, py_tree.objective.AbstractUpliftObjective):
1005+
raise ValueError(
1006+
"Uplift objective not supported by Gradient Boosted Tree models."
1007+
)
1008+
1009+
elif not isinstance(node.value, py_tree.value.RegressionValue):
9761010
raise ValueError(
9771011
"A GBT model should only have leaf with regressive "
9781012
f"value. Got {node.value} instead."

tensorflow_decision_forests/component/inspector/inspector.py

+16
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,22 @@ def objective(self) -> py_tree.objective.AbstractObjective:
252252
return py_tree.objective.RankingObjective(
253253
label=label.name, group=group_column.name)
254254

255+
elif self.task == Task.CATEGORICAL_UPLIFT:
256+
uplift_treatment = self._dataspec.columns[
257+
self._header.uplift_treatment_col_idx
258+
]
259+
return py_tree.objective.CategoricalUpliftObjective(
260+
label=label.name, treatment=uplift_treatment.name
261+
)
262+
263+
elif self.task == Task.NUMERICAL_UPLIFT:
264+
uplift_treatment = self._dataspec.columns[
265+
self._header.uplift_treatment_col_idx
266+
]
267+
return py_tree.objective.NumericalUpliftObjective(
268+
label=label.name, treatment=uplift_treatment.name
269+
)
270+
255271
else:
256272
raise NotImplementedError()
257273

tensorflow_decision_forests/component/py_tree/objective.py

+43
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,46 @@ def __eq__(self, other):
155155
if not isinstance(other, RankingObjective):
156156
return False
157157
return self.label == other.label and self._group == other._group
158+
159+
160+
class AbstractUpliftObjective(AbstractObjective):
161+
"""Objective for Uplift."""
162+
163+
def __init__(self, label: str, treatment: str):
164+
super(AbstractUpliftObjective, self).__init__(label)
165+
self._treatment = treatment
166+
167+
@property
168+
def treatment(self) -> str:
169+
return self._treatment
170+
171+
def __eq__(self, other):
172+
if not isinstance(other, AbstractUpliftObjective):
173+
return False
174+
return (
175+
self.label == other.label
176+
and self._treatment == other._treatment
177+
and self.task == other.task
178+
)
179+
180+
181+
class CategoricalUpliftObjective(AbstractUpliftObjective):
182+
"""Objective for Categorical Uplift."""
183+
184+
@property
185+
def task(self) -> Task:
186+
return Task.CATEGORICAL_UPLIFT
187+
188+
def __repr__(self):
189+
return f"CategoricalUplift(label={self.label}, treatment={self._treatment})"
190+
191+
192+
class NumericalUpliftObjective(AbstractUpliftObjective):
193+
"""Objective for Numerical Uplift."""
194+
195+
@property
196+
def task(self) -> Task:
197+
return Task.NUMERICAL_UPLIFT
198+
199+
def __repr__(self):
200+
return f"NumericalUplift(label={self.label}, treatment={self._treatment})"

tensorflow_decision_forests/component/py_tree/objective_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,27 @@ def test_ranking(self):
6666
objective = objective_lib.RankingObjective(label="label", group="group")
6767
logging.info("objective: %s", objective)
6868

69+
def test_numerical_uplift(self):
70+
objective = objective_lib.NumericalUpliftObjective(
71+
label="label", treatment="treatment"
72+
)
73+
logging.info("objective: %s", objective)
74+
75+
def test_categorical_uplift(self):
76+
objective = objective_lib.CategoricalUpliftObjective(
77+
label="label", treatment="treatment"
78+
)
79+
logging.info("objective: %s", objective)
80+
81+
def test_uplift_objects_are_not_equal(self):
82+
numerical_objective = objective_lib.NumericalUpliftObjective(
83+
label="label", treatment="treatment"
84+
)
85+
categorical_objective = objective_lib.CategoricalUpliftObjective(
86+
label="label", treatment="treatment"
87+
)
88+
self.assertNotEqual(numerical_objective, categorical_objective)
89+
6990

7091
if __name__ == "__main__":
7192
tf.test.main()

tensorflow_decision_forests/keras/core_inference.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,6 @@ def yggdrasil_model_to_keras_model(
12391239
"file containing among other things, a data_spec.pb file."
12401240
)
12411241

1242-
temp_directory = None
12431242
if src_container == "zip":
12441243
# Unzip the model in a temporary directory
12451244
temp_directory = tempfile.TemporaryDirectory()
@@ -1255,6 +1254,13 @@ def yggdrasil_model_to_keras_model(
12551254
ranking_group=objective.group
12561255
if objective.task == inspector_lib.Task.RANKING
12571256
else None,
1257+
uplift_treatment=objective.treatment
1258+
if objective.task
1259+
in (
1260+
inspector_lib.Task.CATEGORICAL_UPLIFT,
1261+
inspector_lib.Task.NUMERICAL_UPLIFT,
1262+
)
1263+
else None,
12581264
verbose=verbose,
12591265
advanced_arguments=AdvancedArguments(
12601266
disable_categorical_integer_offset_correction=disable_categorical_integer_offset_correction,

tensorflow_decision_forests/keras/keras_test.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -2293,7 +2293,9 @@ def test_golden_model_gbt(self):
22932293
("adult_binary_class_gbdt", 0.012131),
22942294
("prefixed_adult_binary_class_gbdt", 0.012131),
22952295
)
2296-
def test_ydf_to_keras_model(self, ydf_model_directory, expected_prediction):
2296+
def test_ydf_to_keras_model_adult(
2297+
self, ydf_model_directory, expected_prediction
2298+
):
22972299
ygg_model_path = os.path.join(
22982300
ydf_test_data_path(), "model", ydf_model_directory
22992301
)
@@ -2328,6 +2330,28 @@ def custom_model_input_signature(
23282330
)
23292331
self.assertNear(prediction[0, 0], expected_prediction, 0.00001)
23302332

2333+
def test_ydf_to_keras_model_uplift(self):
2334+
ygg_model_path = os.path.join(
2335+
ydf_test_data_path(), "model", "sim_pte_categorical_uplift_rf"
2336+
)
2337+
tfdf_model_path = os.path.join(tmp_path(), "sim_pte_categorical_uplift_rf")
2338+
2339+
dataset_directory = os.path.join(ydf_test_data_path(), "dataset")
2340+
test_path = os.path.join(dataset_directory, "sim_pte_test.csv")
2341+
test_df = pd.read_csv(test_path)
2342+
2343+
outcome_key = "y"
2344+
treatment_group = "treat"
2345+
# Remove the treatment group from the test dataset.
2346+
test_df = test_df.drop(treatment_group, axis=1)
2347+
2348+
core.yggdrasil_model_to_keras_model(ygg_model_path, tfdf_model_path)
2349+
loaded_model = models.load_model(tfdf_model_path)
2350+
prediction = loaded_model.predict(
2351+
keras.pd_dataframe_to_tf_dataset(test_df, label=outcome_key)
2352+
)
2353+
self.assertNear(prediction[0, 0], -0.7580058, 0.00001)
2354+
23312355
@parameterized.parameters(
23322356
"directory",
23332357
"zip",

0 commit comments

Comments
 (0)