-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Tabnet support #168
Add Tabnet support #168
Conversation
…P, update logging
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
@robsdavis please review as well.
Thanks for you work on this @tztsai! I'll do my review and look at merging, if you can get the tests to pass again. Thanks so much! |
@bcebere @robsdavis There is an error when installing dependencies. Could you help check it? |
I think the cause is the new Pip 23.1 released on April 15th. I kept its version to 23.0.1 and now it works fine. |
Thanks @tztsai for fixing that. Simply pinning the version of pip is not normally something I would like to do, but we have already pinned the version for a couple of dependencies (and have issues logged to remove these constraints). So, we can look at removing this constraint on pip at the same time. |
Description
I have added the TabNet model now. It is added as a building block in other models, similar to MLP. To incorporate the new model in current plugins like tabular_gan, tabddpm, etc., My approach is to remove the hyper-parameters specific to the MLP model and replace them with two parameters
model_type
andmodel_params
. I have created a file "factory.py" that contains functions that given a name string and a dict of params, will return an nn module, or an activation layer, or a feature encoder. Theget_model
function, provided withmodel_type
andmodel_params
, can instantiate an NN block such as anMLP
,TransformerModel
, orTabNet
. Consequently, by substituting MLP blocks with the dynamically instantiated blocks, we can provide more flexibility to the present plugins. I have tried usingTabNet
as the diffusion model of TabDDPM, instead ofMLP
, and it can successfully fit and generate data, although the performance is acutally not as good as MLP.Checklist