Add Tabnet support (#168)
* first commit for the addition of the TabDDPM plugin

* Add DDPM test script and update DDPM plugin

* add TabDDPM class and refactor

* handle discrete cols and label generation

* add hparam space and update tests of DDPM

* debug and test DDPM

* update TensorDataLoader and training loop

* clear bugs

* debug for regression tasks

* debug for regression tasks; ALL TESTS PASSED

* remove the official repo of TabDDPM

* passed all pre-commit checks

* convert assert to conditional AssertionErrors

* added an auto annotation tool

* update auto-anno and generate annotations

* remove auto-anno and flake8 noqa

* add python<3.9 compatible annotations

* remove star import

* replace builtin type annos to typing annos

* resolve py38 compatibility issue

* tests/plugins/generic/test_ddpm.py

* change TabDDPM method signatures

* remove Iterator subscription

* update AssertionErrors, add EarlyStop callback, removed additional MLP, update logging

* remove TensorDataLoader, update test_ddpm

* update EarlyStopping

* add TabDDPM tutorial, update TabDDPM plugin and encoders

* add TabDDPM tutorial

* major update of FeatureEncoder and TabularEncoder

* add LogDistribution and LogIntDistribution

* update DDPM to use TabularEncoder

* update test_tabular_encoder and debug

* debug and DDPM tutorial OK

* debug LogDistribution and LogIntDistribution

* change discrete encoding of BinEncoder to passthrough;  passed all tests in test_tabular_encoder

* add tabnet to plugins/core/models

* add factory.py, let DDPM use TabNet, refactor

* update docstrings and refactor

* fix type annotation compatibility

* make SkipConnection serializable

* fix TabularEncoder.activation_layout

* remove unnecessary code

* fix minor bug and add more nn models in factory

* update pandas and torch version requirement

* update pandas and torch version requirement

* update ddpm tutorial

* restore setup.cfg

* restore setup.cfg

* replace LabelEncoder with OrdinalEncoder

* update setup.cfg

* update setup.cfg

* debug datetimeDistribution

* clean

* update setup.cfg and goggle test

* move DDPM tutorial to tutorials/plugins

* update tabnet.py reference

* update tab_ddpm

* update

* try fixing goggle

* add more activations

* minor fix

* update

* update

* update

* update

* Update tabular_encoder.py

* Update test_goggle.py

* Update tabular_encoder.py

* update

* update

* default cat nonlin of goggle <- gumbel_softmax

* get_nonlin('softmax') <- GumbelSoftmax()

* remove debug logging

* update

* update

* fix merge

* update pip upgrade commands in workflows

* keep pip version to 23.0.1 in workflows

---------

Co-authored-by: Bogdan Cebere <[email protected]>
Co-authored-by: Rob <[email protected]>
3 people authored Apr 20, 2023
1 parent 59108bf commit a4190e6
Showing 5 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_full.yml
@@ -27,8 +27,8 @@ jobs:
         if: ${{ matrix.os == 'macos-latest' }}
       - name: Install dependencies
         run: |
+          pip install pip==23.0.1
           pip install -r prereq.txt
-          pip install --upgrade pip
       - name: Test Core
         run: |
           pip install .[testing]
2 changes: 1 addition & 1 deletion .github/workflows/test_pr.yml
@@ -54,8 +54,8 @@ jobs:
         if: ${{ matrix.os == 'macos-latest' }}
       - name: Install dependencies
         run: |
+          pip install pip==23.0.1
           pip install -r prereq.txt
-          pip install --upgrade pip
       - name: Test Core
         run: |
           pip install .[testing]
2 changes: 1 addition & 1 deletion .github/workflows/test_tutorials.yml
@@ -32,8 +32,8 @@ jobs:
         if: ${{ matrix.os == 'macos-latest' }}
       - name: Install dependencies
         run: |
+          pip install pip==23.0.1
           pip install -r prereq.txt
-          pip install --upgrade pip
           pip install .[all]
4 changes: 2 additions & 2 deletions src/synthcity/plugins/core/models/factory.py
@@ -20,9 +20,9 @@
     DatetimeEncoder,
     FeatureEncoder,
     GaussianQuantileTransformer,
-    LabelEncoder,
     MinMaxScaler,
     OneHotEncoder,
+    OrdinalEncoder,
     RobustScaler,
     StandardScaler,
 )
@@ -74,7 +74,7 @@
 FEATURE_ENCODERS = dict(
     datetime=DatetimeEncoder,
     onehot=OneHotEncoder,
-    label=LabelEncoder,
+    ordinal=OrdinalEncoder,
     standard=StandardScaler,
     minmax=MinMaxScaler,
     robust=RobustScaler,
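The `FEATURE_ENCODERS` change above swaps the `label` key for `ordinal` in a name-to-class registry. A minimal sketch of how such a registry is typically consumed is given here; the `make_encoder` helper and the empty stand-in classes are hypothetical and are not taken from `factory.py`, whose lookup code is not shown in this diff.

```python
from typing import Any, Dict, Type


class OneHotEncoder:
    """Hypothetical stand-in; not the real synthcity class."""


class OrdinalEncoder:
    """Hypothetical stand-in; not the real synthcity class."""


# Registry keyed by short name, mirroring the shape of the dict in the diff.
# Only two entries are reproduced to keep the sketch self-contained.
FEATURE_ENCODERS: Dict[str, Type[Any]] = dict(
    onehot=OneHotEncoder,
    ordinal=OrdinalEncoder,
)


def make_encoder(name: str, **kwargs: Any) -> Any:
    """Hypothetical helper: look up an encoder class by name and instantiate it."""
    try:
        encoder_cls = FEATURE_ENCODERS[name]
    except KeyError:
        known = ", ".join(sorted(FEATURE_ENCODERS))
        raise ValueError(
            f"unknown encoder {name!r}; expected one of: {known}"
        ) from None
    return encoder_cls(**kwargs)
```

Under this sketch, a caller that still requests `"label"` gets a `ValueError` listing the valid keys, which is the kind of breakage to expect from renaming a registry key; such callers would need to request `"ordinal"` instead.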
5 changes: 5 additions & 0 deletions src/synthcity/plugins/core/models/tabnet.py
@@ -1,3 +1,8 @@
+# TabNet: Attentive Interpretable Tabular Learning
+# Reference:
+# - https://arxiv.org/pdf/1908.07442.pdf
+# - https://github.com/dreamquark-ai/tabnet
+
 # stdlib
 from typing import List, Optional, Tuple

