[WIP] 58 add goggle #162
        **kwargs: Any,
    ) -> None:
        kwargs.setdefault("aggr", aggr)
        super().__init__(node_dim=0, **kwargs)
Is this `node_dim` a parameter, or does it need to be hardcoded to 3?
I have exposed it as a parameter now.
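A minimal sketch of what exposing `node_dim` might look like. `BaseLayer` is a stand-in for the real message-passing base class and `GraphLayer` is a hypothetical name; neither is taken from this PR.

```python
from typing import Any


class BaseLayer:
    """Stand-in for the real message-passing base class (assumption)."""

    def __init__(self, node_dim: int = 0, **kwargs: Any) -> None:
        self.node_dim = node_dim
        self.aggr = kwargs.get("aggr")


class GraphLayer(BaseLayer):
    """Hypothetical layer that exposes node_dim instead of hardcoding it."""

    def __init__(self, aggr: str = "mean", node_dim: int = 0, **kwargs: Any) -> None:
        # Forward the aggregation scheme to the base class through kwargs.
        kwargs.setdefault("aggr", aggr)
        # node_dim now comes from the caller, with the old value as default.
        super().__init__(node_dim=node_dim, **kwargs)


layer = GraphLayer(node_dim=1)
print(layer.node_dim, layer.aggr)
```

Keeping the previous value as the default means existing callers are unaffected, while tests can vary the node dimension.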
            **kwargs,
        )

    def enforce_constraints(self, X_synth: torch.Tensor) -> np.ndarray:
Is this still needed?
Nope, will delete in next push.
tests/plugins/generic/test_goggle.py
Outdated
    assert np.mean(results) > 0.7


# @pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
If there are no plans for conditional generation, you should delete these comments.
I am hoping to add conditional generation at a later date, but I will delete these in the next push, as I haven't seen the research code for it yet.
LGTM
.github/workflows/test_full.yml
Outdated
@@ -28,6 +28,7 @@ jobs:
      - name: Install dependencies
        run: |
          pip install -r prereq.txt
          pip install .[goggle]
Why the separate install profile? Isn't goggle included in the default library?
After talking with Evgeny, this seems like a better way to handle the dependencies. The issue is that `torch_scatter` and `torch_sparse` need to be installed from source, with a `--find-links` flag or similar. These dependencies are very tricky.
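The optional-profile approach could look roughly like the mapping below, which would be passed as `extras_require` to `setuptools.setup`. The package lists and the `testing` profile are assumptions for illustration, not the project's actual configuration.

```python
# Hypothetical fragment of a setup.py; the package lists are illustrative only.
extras = {
    # Heavy graph dependencies stay out of the default install.
    "goggle": ["torch_scatter", "torch_sparse"],
    # Assumed second profile, included only to show how profiles combine.
    "testing": ["pytest"],
}

# An "all" target is the de-duplicated union of every optional profile.
extras["all"] = sorted({pkg for pkgs in extras.values() for pkg in pkgs})

print(extras["all"])
```

With a mapping like this, `pip install .[goggle]` pulls in only the graph dependencies, while the default install stays free of packages that need compiling.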
src/synthcity/plugins/core/plugin.py
Outdated
@@ -548,8 +548,22 @@ def __init__(self, plugins: list, expected_type: Type, categories: list) -> None
        self._available_plugins = {}
        for plugin in plugins:
            stem = Path(plugin).stem.split("plugin_")[-1]
            self._available_plugins[stem] = plugin

            if stem == "goggle":
Please don't include specific plugin conditions in `plugin.py`; this should stay generic.
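One way to keep the loader generic is to handle a missing optional dependency the same way for every plugin, rather than branching on a plugin name. The sketch below reuses the stem logic from the diff; `index_plugins` and `load_or_skip` are hypothetical helpers, not functions from the codebase.

```python
from pathlib import Path
from typing import Callable, Dict, List, Optional


def index_plugins(plugins: List[str]) -> Dict[str, str]:
    """Map plugin names to their paths without special-casing any plugin."""
    available: Dict[str, str] = {}
    for plugin in plugins:
        stem = Path(plugin).stem.split("plugin_")[-1]
        available[stem] = plugin
    return available


def load_or_skip(name: str, loader: Callable[[], object]) -> Optional[object]:
    """Hypothetical helper: run a loader, skipping the plugin on ImportError."""
    try:
        return loader()
    except ImportError as err:
        # Same handling for every plugin, so no name-specific branches are needed.
        print(f"Skipping plugin {name!r}: {err}")
        return None
```

A plugin whose optional dependencies are missing is then skipped by the same code path as any other plugin that fails to import.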
- add `all` install target
- drop the goggle mark for tests and the checks from the plugin

torch-sparse requires compiling. `pip` is not aware of the local machine's architecture, so the library has to be compiled when installing with pip. `conda install` can handle the architecture, but you need to make sure you match the torch version, otherwise it will crash. For now, pip compiles the library every time. However, the GitHub Actions workflows cache the `pip` stage, so it won't take forever every time.
Description
Added the synthetic data model "goggle". This version does not support conditional generation. Closes #58.
Goggle paper: "GOGGLE: Generative Modelling for Tabular Data by Learning Relational Structure" Authors: Tennison Liu, Zhaozhi Qian, Jeroen Berrevoets, Mihaela van der Schaar
Affected Dependencies
How has this been tested?
- test_plugin_sanity
- test_plugin_name
- test_plugin_type
- test_plugin_hyperparams
- test_plugin_fit
- test_plugin_generate
- test_plugin_generate_constraints_goggle
- test_sample_hyperparams
- test_eval_performance_goggle
- test_plugin_encoding
pytest -vvvs tests/plugins/generic/test_goggle.py
Checklist