Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Tabnet support #168

Merged
merged 91 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
2ecba4f
first commit for the addition of the TabDDPM plugin
Mar 1, 2023
fed898b
Add DDPM test script and update DDPM plugin
Mar 3, 2023
34979cf
add TabDDPM class and refactor
Mar 5, 2023
0abdc01
handle discrete cols and label generation
Mar 7, 2023
405a052
add hparam space and update tests of DDPM
Mar 7, 2023
0e36041
debug and test DDPM
Mar 7, 2023
fc9cee0
update TensorDataLoader and training loop
Mar 7, 2023
d8b57ad
clear bugs
Mar 7, 2023
92dcc32
debug for regression tasks
Mar 7, 2023
0b9d0e3
debug for regression tasks; ALL TESTS PASSED
Mar 7, 2023
6e6fe41
Merge branch 'tab_ddpm' of https://github.com/TZTsai/synthcity into t…
Mar 7, 2023
bb98229
remove the official repo of TabDDPM
Mar 7, 2023
b4486a4
passed all pre-commit checks
Mar 8, 2023
2a9aa2a
convert assert to conditional AssertionErrors
Mar 8, 2023
246cd5b
added an auto annotation tool
Mar 10, 2023
f458bb4
update auto-anno and generate annotations
Mar 10, 2023
137c176
remove auto-anno and flake8 noqa
Mar 10, 2023
6c4af11
add python<3.9 compatible annotations
Mar 10, 2023
191cdcc
remove star import
Mar 10, 2023
9349a66
replace builtin type annos to typing annos
Mar 12, 2023
02579e9
resolve py38 compatibility issue
Mar 12, 2023
f930bc0
tests/plugins/generic/test_ddpm.py
Mar 12, 2023
3cf73d7
change TabDDPM method signatures
Mar 13, 2023
5d37c4b
remove Iterator subscription
Mar 13, 2023
681ba60
update AssertionErrors, add EarlyStop callback, removed additional ML…
Mar 15, 2023
bcbc131
Merge branch 'main' into tab_ddpm
bcebere Mar 15, 2023
a9438dc
remove TensorDataLoader, update test_ddpm
Mar 16, 2023
52be80f
update EarlyStopping
Mar 16, 2023
794ebd6
add TabDDPM tutorial, update TabDDPM plugin and encoders
Mar 27, 2023
bcdce4b
add TabDDPM tutorial
Mar 27, 2023
8120e97
major update of FeatureEncoder and TabularEncoder
Mar 30, 2023
2750791
add LogDistribution and LogIntDistribution
Mar 30, 2023
52011d3
update DDPM to use TabularEncoder
Mar 30, 2023
0ee6c8b
update test_tabular_encoder and debug
Mar 30, 2023
244854d
debug and DDPM tutorial OK
Mar 30, 2023
e336d3c
Merge branch 'main' of https://github.com/vanderschaarlab/synthcity
Mar 30, 2023
c847c95
Merge branch 'main' into tab_ddpm
Mar 30, 2023
428177b
debug LogDistribution and LogIntDistribution
Mar 31, 2023
3377a95
Merge branch 'main' into tab_ddpm
Mar 31, 2023
4705319
change discrete encoding of BinEncoder to passthrough; passed all te…
Apr 1, 2023
d9d73f1
add tabnet to plugins/core/models
Apr 2, 2023
d29ef37
add factory.py, let DDPM use TabNet, refactor
Apr 2, 2023
6e58cf3
update docstrings and refactor
Apr 2, 2023
2a6ca6f
fix type annotation compatibility
Apr 2, 2023
36acaa0
make SkipConnection serializable
Apr 3, 2023
de15b9b
fix TabularEncoder.activation_layout
Apr 3, 2023
694cd22
remove unnecessary code
Apr 3, 2023
a459785
fix minor bug and add more nn models in factory
Apr 6, 2023
57816b6
update pandas and torch version requirement
Apr 6, 2023
cc7e8fb
update pandas and torch version requirement
Apr 6, 2023
f20db25
Merge branch 'main' into tabnet
Apr 6, 2023
7b0c19a
Merge branch 'main' into tab_ddpm
Apr 6, 2023
8a58996
update ddpm tutorial
Apr 6, 2023
31b5f13
Merge branch 'tab_ddpm' of https://github.com/TZTsai/synthcity into t…
Apr 6, 2023
cef348e
restore setup.cfg
Apr 6, 2023
9cb5da1
restore setup.cfg
Apr 6, 2023
fe5ff25
replace LabelEncoder with OrdinalEncoder
Apr 7, 2023
2922a1d
update setup.cfg
Apr 7, 2023
11fb825
update setup.cfg
Apr 7, 2023
9222b4e
debug datetimeDistribution
Apr 7, 2023
7d55c65
Merge branch 'tab_ddpm' into tabnet
Apr 7, 2023
95302b9
clean
Apr 7, 2023
785db82
update setup.cfg and goggle test
Apr 7, 2023
44ead6d
Merge branch 'tab_ddpm' into tabnet
Apr 7, 2023
27cc95c
move DDPM tutorial to tutorials/plugins
Apr 7, 2023
1d7c77c
update tabnet.py reference
Apr 7, 2023
6c25377
update tab_ddpm
Apr 7, 2023
2fb8508
update
Apr 8, 2023
4a7e73b
try fixing goggle
Apr 8, 2023
8051caa
add more activations
Apr 8, 2023
3cd9917
minor fix
Apr 8, 2023
42cbe8c
update
Apr 9, 2023
101c76f
Merge branch 'tab_ddpm' into tabnet
Apr 9, 2023
7c58f2d
update
Apr 9, 2023
104e3a3
update
Apr 9, 2023
7b4e04a
update
Apr 9, 2023
fede549
Update tabular_encoder.py
Apr 10, 2023
539effa
Update test_goggle.py
Apr 10, 2023
0cb9f25
Update tabular_encoder.py
Apr 10, 2023
42c6941
update
Apr 10, 2023
e20e581
update
Apr 10, 2023
472ad52
default cat nonlin of goggle <- gumbel_softmax
Apr 10, 2023
5dbe666
get_nonlin('softmax') <- GumbelSoftmax()
Apr 10, 2023
74e897b
remove debug logging
Apr 10, 2023
27553e9
update
Apr 10, 2023
b5eb2e7
Merge branch 'tab_ddpm' into tabnet
Apr 10, 2023
7fc5ce4
update
Apr 10, 2023
8af4966
Merge branch 'main' into tabnet
robsdavis Apr 18, 2023
ecc9d08
fix merge
Apr 18, 2023
1d9c7a4
update pip upgrade commands in workflows
Apr 19, 2023
385d2ed
keep pip version to 23.0.1 in workflows
Apr 19, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ jobs:
if: ${{ matrix.os == 'macos-latest' }}
- name: Install dependencies
run: |
pip install pip==23.0.1
pip install -r prereq.txt
pip install --upgrade pip
- name: Test Core
run: |
pip install .[testing]
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ jobs:
if: ${{ matrix.os == 'macos-latest' }}
- name: Install dependencies
run: |
pip install pip==23.0.1
pip install -r prereq.txt
pip install --upgrade pip
- name: Test Core
run: |
pip install .[testing]
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ jobs:
if: ${{ matrix.os == 'macos-latest' }}
- name: Install dependencies
run: |
pip install pip==23.0.1
pip install -r prereq.txt
pip install --upgrade pip

pip install .[all]

Expand Down
4 changes: 2 additions & 2 deletions src/synthcity/plugins/core/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
DatetimeEncoder,
FeatureEncoder,
GaussianQuantileTransformer,
LabelEncoder,
MinMaxScaler,
OneHotEncoder,
OrdinalEncoder,
RobustScaler,
StandardScaler,
)
Expand Down Expand Up @@ -74,7 +74,7 @@
FEATURE_ENCODERS = dict(
datetime=DatetimeEncoder,
onehot=OneHotEncoder,
label=LabelEncoder,
ordinal=OrdinalEncoder,
standard=StandardScaler,
minmax=MinMaxScaler,
robust=RobustScaler,
Expand Down
5 changes: 5 additions & 0 deletions src/synthcity/plugins/core/models/tabnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# TabNet: Attentive Interpretable Tabular Learning
# Reference:
# - https://arxiv.org/pdf/1908.07442.pdf
# - https://github.com/dreamquark-ai/tabnet

# stdlib
from typing import List, Optional, Tuple

Expand Down