Skip to content

Commit

Permalink
Make it possible to set the resulting model name from the TFDF Builde…
Browse files Browse the repository at this point in the history
…r APIs

PiperOrigin-RevId: 586072410
  • Loading branch information
SimpleML Team authored and copybara-github committed Nov 28, 2023
1 parent ed8d94e commit 5f821eb
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
11 changes: 10 additions & 1 deletion tensorflow_decision_forests/component/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def __init__(
file_prefix: Optional[str] = None,
verbose: int = 1,
advanced_arguments: Optional[AdvancedArguments] = None,
keras_model_name: Optional[str] = None,
):
if not path:
raise ValueError("The path cannot be empty")
Expand All @@ -208,6 +209,7 @@ def __init__(
if self._file_prefix is None:
self._file_prefix = keras_core.generate_training_id()
self._verbose = verbose
self._keras_model_name = keras_model_name

self._header.name = self.model_type()
self._header.task = objective.task
Expand Down Expand Up @@ -276,6 +278,7 @@ def close(self):
verbose=self._verbose,
disable_categorical_integer_offset_correction=self._advanced_arguments.disable_categorical_integer_offset_correction,
allow_slow_inference=self._advanced_arguments.allow_slow_inference,
keras_model_name=self._keras_model_name,
)
tf.io.gfile.rmtree(self.yggdrasil_model_path())

Expand Down Expand Up @@ -594,6 +597,7 @@ def __init__(
file_prefix: Optional[str] = None,
verbose: int = 1,
advanced_arguments: Optional[AdvancedArguments] = None,
keras_model_name: Optional[str] = None,
):
super(AbstractDecisionForestBuilder, self).__init__(
path,
Expand All @@ -604,6 +608,7 @@ def __init__(
file_prefix,
verbose,
advanced_arguments,
keras_model_name=keras_model_name,
)

self._trees = []
Expand Down Expand Up @@ -823,6 +828,7 @@ def __init__(
file_prefix: Optional[str] = None,
verbose: int = 1,
advanced_arguments: Optional[AdvancedArguments] = None,
keras_model_name: Optional[str] = None,
):
self._specialized_header = random_forest_pb2.Header(
winner_take_all_inference=winner_take_all
Expand All @@ -838,6 +844,7 @@ def __init__(
file_prefix,
verbose,
advanced_arguments,
keras_model_name,
)

def model_type(self) -> str:
Expand Down Expand Up @@ -913,6 +920,7 @@ def __init__(
file_prefix: Optional[str] = None,
verbose: int = 1,
advanced_arguments: Optional[AdvancedArguments] = None,
keras_model_name: Optional[str] = None,
):
# Compute the number of tree per iterations and loss.
#
Expand Down Expand Up @@ -974,7 +982,8 @@ def __init__(
file_prefix,
verbose,
advanced_arguments,
)
keras_model_name,
)

def model_type(self) -> str:
return "GRADIENT_BOOSTED_TREES"
Expand Down
34 changes: 24 additions & 10 deletions tensorflow_decision_forests/component/builder/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,21 @@ def test_dataset_directory() -> str:

class BuilderTest(parameterized.TestCase, tf.test.TestCase):

@parameterized.parameters((None,), ("",), ("prefix_",))
def test_classification_random_forest(self, file_prefix):
@parameterized.parameters(
(None, None), ("", "abc123"), ("prefix_", "test_model")
)
def test_classification_random_forest(self, file_prefix, model_name):
model_path = os.path.join(tmp_path(), "classification_rf")
logging.info("Create model in %s", model_path)
builder = builder_lib.RandomForestBuilder(
path=model_path,
model_format=builder_lib.ModelFormat.TENSORFLOW_SAVED_MODEL,
objective=py_tree.objective.ClassificationObjective(
label="color", classes=["red", "blue", "green"]),
file_prefix=file_prefix)
label="color", classes=["red", "blue", "green"]
),
file_prefix=file_prefix,
keras_model_name=model_name,
)

# f1>=1.5
# │
Expand Down Expand Up @@ -118,7 +123,10 @@ def test_classification_random_forest(self, file_prefix):

logging.info("Loading model")
loaded_model = tf.keras.models.load_model(model_path)

expected_model_name = (
"inference_core_model" if model_name is None else model_name
)
self.assertEqual(loaded_model.name, expected_model_name)
logging.info("Make predictions")
tf_dataset = tf.data.Dataset.from_tensor_slices({
"f1": [1.0, 2.0, 3.0],
Expand All @@ -136,8 +144,11 @@ def test_classification_cart(self, file_prefix):
path=model_path,
model_format=builder_lib.ModelFormat.TENSORFLOW_SAVED_MODEL,
objective=py_tree.objective.ClassificationObjective(
label="color", classes=["red", "blue", "green"]),
file_prefix=file_prefix)
label="color", classes=["red", "blue", "green"]
),
file_prefix=file_prefix,
keras_model_name="classification_cart",
)

# f1>=1.5
# ├─(pos)─ f2 in ["cat","dog"]
Expand Down Expand Up @@ -178,7 +189,7 @@ def test_classification_cart(self, file_prefix):

logging.info("Loading model")
loaded_model = tf.keras.models.load_model(model_path)

self.assertEqual(loaded_model.name, "classification_cart")
logging.info("Make predictions")
tf_dataset = tf.data.Dataset.from_tensor_slices({
"f1": [1.0, 2.0, 3.0],
Expand Down Expand Up @@ -271,7 +282,10 @@ def test_binary_classification_gbt(self):
model_format=builder_lib.ModelFormat.TENSORFLOW_SAVED_MODEL,
bias=1.0,
objective=py_tree.objective.ClassificationObjective(
label="color", classes=["red", "blue"]))
label="color", classes=["red", "blue"]
),
keras_model_name="binary_classification_gbt",
)

# bias: 1.0 (toward "blue")
# f1>=1.5
Expand All @@ -294,7 +308,7 @@ def test_binary_classification_gbt(self):

logging.info("Loading model")
loaded_model = tf.keras.models.load_model(model_path)

self.assertEqual(loaded_model.name, "binary_classification_gbt")
logging.info("Make predictions")
tf_dataset = tf.data.Dataset.from_tensor_slices({
"f1": [1.0, 2.0],
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_decision_forests/keras/core_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,7 @@ def yggdrasil_model_to_keras_model(
verbose: int = 1,
disable_categorical_integer_offset_correction: bool = False,
allow_slow_inference: bool = True,
keras_model_name: Optional[str] = None,
) -> None:
"""Converts an Yggdrasil model into a TensorFlow SavedModel / Keras model.
Expand All @@ -1229,6 +1230,7 @@ def yggdrasil_model_to_keras_model(
integer offset correction. See
disable_categorical_integer_offset_correction in AdvancedArguments for
more details.
name: Optional name of the resulting model.
"""

# Detect the container of the model.
Expand Down Expand Up @@ -1270,6 +1272,7 @@ def yggdrasil_model_to_keras_model(
disable_categorical_integer_offset_correction=disable_categorical_integer_offset_correction,
allow_slow_inference=allow_slow_inference,
),
name=keras_model_name,
)

model._set_from_yggdrasil_model( # pylint: disable=protected-access
Expand Down

0 comments on commit 5f821eb

Please sign in to comment.