Skip to content

Commit

Permalink
Fix plotting contains conditions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588330281
  • Loading branch information
rstz authored and copybara-github committed Dec 6, 2023
1 parent fde0402 commit bfd63a5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensorflow_decision_forests/component/py_tree/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def core_condition_to_condition(
core_condition.split_score)

if condition_type.HasField("contains_condition"):
items = condition_type.contains_condition.elements
items = list(condition_type.contains_condition.elements)
if not column_spec.categorical.is_already_integerized:
items = [
dataspec_lib.categorical_value_idx_to_value(column_spec, item)
Expand Down
10 changes: 10 additions & 0 deletions tensorflow_decision_forests/keras/keras_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3060,6 +3060,16 @@ def test_abalone_poison_loss(self):
predictions = model.predict(tf_test)
logging.info("Predictions: %s", predictions)

def test_plot_contains_condition(self):
df = pd.DataFrame({"f": [-10, 0, 1000], "label": [0, 1, 2]})
ds = keras.pd_dataframe_to_tf_dataset(df, label="label")
model = keras.CartModel(
min_examples=1,
features=[keras.FeatureUsage("f", keras.FeatureSemantic.CATEGORICAL)],
)
model.fit(ds)
model_plotter.plot_model(model)


if __name__ == "__main__":
tf.test.main()

0 comments on commit bfd63a5

Please sign in to comment.