diff --git a/tensorflow_decision_forests/component/py_tree/condition.py b/tensorflow_decision_forests/component/py_tree/condition.py index d672b7d..3aa978c 100644 --- a/tensorflow_decision_forests/component/py_tree/condition.py +++ b/tensorflow_decision_forests/component/py_tree/condition.py @@ -301,7 +301,7 @@ def core_condition_to_condition( core_condition.split_score) if condition_type.HasField("contains_condition"): - items = condition_type.contains_condition.elements + items = list(condition_type.contains_condition.elements) if not column_spec.categorical.is_already_integerized: items = [ dataspec_lib.categorical_value_idx_to_value(column_spec, item) diff --git a/tensorflow_decision_forests/keras/keras_test.py b/tensorflow_decision_forests/keras/keras_test.py index 3d09726..14e5a87 100644 --- a/tensorflow_decision_forests/keras/keras_test.py +++ b/tensorflow_decision_forests/keras/keras_test.py @@ -3060,6 +3060,16 @@ def test_abalone_poison_loss(self): predictions = model.predict(tf_test) logging.info("Predictions: %s", predictions) + def test_plot_contains_condition(self): + df = pd.DataFrame({"f": [-10, 0, 1000], "label": [0, 1, 2]}) + ds = keras.pd_dataframe_to_tf_dataset(df, label="label") + model = keras.CartModel( + min_examples=1, + features=[keras.FeatureUsage("f", keras.FeatureSemantic.CATEGORICAL)], + ) + model.fit(ds) + model_plotter.plot_model(model) + if __name__ == "__main__": tf.test.main()