Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into bugfix/instance_segmentation_example
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Dec 14, 2021
2 parents 0f49513 + 70fcff6 commit 2257e53
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug when not explicitly passing `embedding_sizes` to the `TabularClassifier` and `TabularRegressor` tasks ([#1067](https://github.com/PyTorchLightning/lightning-flash/pull/1067))

### Removed

## [0.6.0] - 2021-12-13
Expand Down
11 changes: 8 additions & 3 deletions flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class TabularClassifier(ClassificationTask):
Args:
num_features: Number of columns in table (not including target column).
num_classes: Number of classes to classify.
embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings.
embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings (or ``None`` if there are no
categorical fields in the data).
loss_fn: Loss function for training, defaults to cross entropy.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
Expand All @@ -63,7 +64,7 @@ def __init__(
self,
num_features: int,
num_classes: int,
embedding_sizes: List[Tuple[int, int]] = None,
embedding_sizes: Optional[List[Tuple[int, int]]] = None,
loss_fn: Callable = F.cross_entropy,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
Expand All @@ -75,7 +76,11 @@ def __init__(
):
self.save_hyperparameters()

cat_dims, cat_emb_dim = zip(*embedding_sizes) if embedding_sizes else ([], [])
if embedding_sizes:
cat_dims, cat_emb_dim = zip(*embedding_sizes)
else:
cat_dims, cat_emb_dim, embedding_sizes = [], [], []

model = TabNet(
input_dim=num_features,
output_dim=num_classes,
Expand Down
11 changes: 8 additions & 3 deletions flash/tabular/regression/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class TabularRegressor(RegressionTask):
Args:
num_features: Number of columns in table (not including target column).
embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings.
embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings (or ``None`` if there are no
categorical fields in the data).
loss_fn: Loss function for training, defaults to mean squared error (``F.mse_loss``).
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
Expand All @@ -60,7 +61,7 @@ class TabularRegressor(RegressionTask):
def __init__(
self,
num_features: int,
embedding_sizes: List[Tuple[int, int]] = None,
embedding_sizes: Optional[List[Tuple[int, int]]] = None,
loss_fn: Callable = F.mse_loss,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
Expand All @@ -71,7 +72,11 @@ def __init__(
):
self.save_hyperparameters()

cat_dims, cat_emb_dim = zip(*embedding_sizes) if embedding_sizes else ([], [])
if embedding_sizes:
cat_dims, cat_emb_dim = zip(*embedding_sizes)
else:
cat_dims, cat_emb_dim, embedding_sizes = [], [], []

model = TabNet(
input_dim=num_features,
output_dim=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/tabular/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_init_train_no_num(tmpdir):
@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.")
def test_init_train_no_cat(tmpdir):
train_dl = torch.utils.data.DataLoader(DummyDataset(num_cat=0), batch_size=16)
model = TabularClassifier(num_classes=10, num_features=16, embedding_sizes=[])
model = TabularClassifier(num_classes=10, num_features=16)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model, train_dl)

Expand Down

0 comments on commit 2257e53

Please sign in to comment.