Skip to content

Commit

Permalink
Constraints improvements: categoricals bugfixing (#105)
Browse files Browse the repository at this point in the history
* fix constraints for categoricals

* bugfixing

* improve docs
  • Loading branch information
bcebere authored Jan 17, 2023
1 parent 4d8f72e commit 3563c8e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 15 deletions.
17 changes: 17 additions & 0 deletions src/synthcity/plugins/core/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,31 @@ def feature_params(self, feature: str) -> Tuple:
dist_template = "integer"
elif (op == "le" or op == "<=") and value < dist_args["high"]:
dist_args["high"] = value
if "choices" in dist_args:
dist_args["choices"] = [
v for v in dist_args["choices"] if v <= value
]
elif (op == "lt" or op == "<") and value < dist_args["high"]:
dist_args["high"] = value - 1
if "choices" in dist_args:
dist_args["choices"] = [
v for v in dist_args["choices"] if v < value
]
elif (op == "ge" or op == ">=") and dist_args["low"] < value:
dist_args["low"] = value
if "choices" in dist_args:
dist_args["choices"] = [
v for v in dist_args["choices"] if v >= value
]
elif (op == "gt" or op == ">") and dist_args["low"] < value:
dist_args["low"] = value + 1
if "choices" in dist_args:
dist_args["choices"] = [
v for v in dist_args["choices"] if v > value
]
elif op == "eq" or op == "==":
dist_args["low"] = value
dist_args["high"] = value
dist_args["choices"] = [value]

return dist_template, dist_args
20 changes: 16 additions & 4 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,23 @@ def generate(
Args:
count: optional int.
The number of samples to generate. If None, it generates len(reference_dataset) samples.
constraints: optional Constraints
cond: Optional, Union[pd.DataFrame, pd.Series, np.ndarray].
Optional Generation Conditional. The conditional can be used only if the model was trained using a conditional too.
If provided, it must have `count` length.
Not all models support conditionals. Conditionals can be used in VAEs or GANs to speed up generation under some constraints. For a model-agnostic solution, see the `constraints` parameter.
constraints: optional Constraints.
Optional constraints to apply to the generated data. If None, the reference schema constraints are applied. The constraints are model-agnostic and filter the output of the generative model.
The constraints are a list of rules. Each rule is a tuple of the form (<feature>, <operation>, <value>).
Valid Operations:
- "<", "lt" : less than <value>
- "<=", "le": less than or equal to <value>
- ">", "gt" : greater than <value>
- ">=", "ge": greater than or equal to <value>
- "==", "eq": equal to <value>
- "in": valid for categorical features; <value> must be an array. For example, ("target", "in", [0, 1])
- "dtype": <value> can be a data type. For example, ("target", "dtype", "int")
Usage example:
>>> from synthcity.plugins.core.constraints import Constraints
>>> constraints = Constraints(
Expand All @@ -286,9 +301,6 @@ def generate(
>>>
>>> assert (syn_data["InterestingFeature"] == 0).all()
cond: Optional, Union[pd.DataFrame, pd.Series, np.ndarray]
Optional Generation Conditional. The conditional can be used only if the model was trained using a conditional too.
If provided, it must have `count` length.
Returns:
<count> synthetic samples
"""
Expand Down
2 changes: 1 addition & 1 deletion src/synthcity/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.4"
__version__ = "0.1.5"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
MINOR_VERSION = __version__.split(".")[-1]
15 changes: 5 additions & 10 deletions tests/plugins/generic/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,27 +79,22 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
def test_plugin_generate_constraints(test_plugin: Plugin) -> None:
X = pd.DataFrame(load_iris()["data"])
def test_plugin_generate_constraints_ctgan(test_plugin: Plugin) -> None:
X, y = load_iris(as_frame=True, return_X_y=True)
X["target"] = y
test_plugin.fit(GenericDataLoader(X))

constraints = Constraints(
rules=[
("0", "le", 6),
("0", "ge", 4.3),
("1", "le", 4.4),
("1", "ge", 3),
("2", "le", 5.5),
("2", "ge", 1.0),
("3", "le", 2),
("3", "ge", 0.1),
("target", "eq", 1),
]
)

X_gen = test_plugin.generate(constraints=constraints).dataframe()
assert len(X_gen) == len(X)
assert test_plugin.schema_includes(X_gen)
assert constraints.filter(X_gen).sum() == len(X_gen)
assert (X_gen["target"] == 1).all()

X_gen = test_plugin.generate(count=50, constraints=constraints).dataframe()
assert len(X_gen) == 50
Expand Down

0 comments on commit 3563c8e

Please sign in to comment.