Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* first commit for the addition of the TabDDPM plugin * Add DDPM test script and update DDPM plugin * add TabDDPM class and refactor * handle discrete cols and label generation * add hparam space and update tests of DDPM * debug and test DDPM * update TensorDataLoader and training loop * clear bugs * debug for regression tasks * debug for regression tasks; ALL TESTS PASSED * remove the official repo of TabDDPM * passed all pre-commit checks * convert assert to conditional AssertionErrors * added an auto annotation tool * update auto-anno and generate annotations * remove auto-anno and flake8 noqa * add python<3.9 compatible annotations * remove star import * replace builtin type annos to typing annos * resolve py38 compatibility issue * tests/plugins/generic/test_ddpm.py * change TabDDPM method signatures * remove Iterator subscription * update AssertionErrors, add EarlyStop callback, removed additional MLP, update logging * remove TensorDataLoader, update test_ddpm * update EarlyStopping * add TabDDPM tutorial, update TabDDPM plugin and encoders * add TabDDPM tutorial * major update of FeatureEncoder and TabularEncoder * add LogDistribution and LogIntDistribution * update DDPM to use TabularEncoder * update test_tabular_encoder and debug * debug and DDPM tutorial OK * debug LogDistribution and LogIntDistribution * change discrete encoding of BinEncoder to passthrough; passed all tests in test_tabular_encoder * add tabnet to plugins/core/models * add factory.py, let DDPM use TabNet, refactor * update docstrings and refactor * fix type annotation compatibility * make SkipConnection serializable * fix TabularEncoder.activation_layout * remove unnecessary code * fix minor bug and add more nn models in factory * update pandas and torch version requirement * update pandas and torch version requirement * update ddpm tutorial * restore setup.cfg * restore setup.cfg * replace LabelEncoder with OrdinalEncoder * update setup.cfg * update setup.cfg * debug datetimeDistribution * clean * update setup.cfg and goggle test * move DDPM tutorial to tutorials/plugins * update tabnet.py reference * update tab_ddpm * update * try fixing goggle * add more activations * minor fix * update * update * update * update * Update tabular_encoder.py * Update test_goggle.py * Update tabular_encoder.py * update * update * default cat nonlin of goggle <- gumbel_softmax * get_nonlin('softmax') <- GumbelSoftmax() * remove debug logging * update * update * fix merge * update pip upgrade commands in workflows * keep pip version to 23.0.1 in workflows --------- Co-authored-by: Bogdan Cebere <[email protected]> Co-authored-by: Rob <[email protected]>
- Loading branch information