Enable Support for Multi-GPU Training (#517)
* Add `pickle_protocol` to `DataConfig` for passing
it down to `torch.save()` when caching datasets to disk

* Save the index of the original dataframe in `TabularDataset`
so that it can be restored when accessing `TabularDataset.data`

* Add `sync_dist=True` to all calls to `self.log()` in
`validation_step()` and `test_step()` to support distributed training

* Fix `TrainerConfig.precision` to be a string and
remove integer choices. Add a pointer to docs with
possible options

* Add `sync_dist` to SSL Models

* Address `PerformanceWarning` related to `frame.insert()`

* Only load best checkpoint on rank zero in distributed
training

* If logging with wandb, `unwatch` the model after training

* Address `FutureWarning` regarding `inplace=True`
sorenmacbeth authored Dec 10, 2024
1 parent b504132 commit f04a05c
Showing 8 changed files with 66 additions and 31 deletions.
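Taken together, these changes wire multi-GPU (DDP) support through the config layer. A minimal usage sketch with the options this commit touches (dataset columns and values below are illustrative assumptions, not taken from the commit):

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

data_config = DataConfig(
    target=["target"],                 # hypothetical column names
    continuous_cols=["num_a", "num_b"],
    categorical_cols=["cat_a"],
    pickle_protocol=4,                 # new in this commit: forwarded to torch.save() when caching
)
trainer_config = TrainerConfig(
    accelerator="gpu",
    devices=2,                         # Lightning runs one process per GPU (DDP)
    precision="16-mixed",              # now a string; valid values depend on your Lightning version
)
model_config = CategoryEmbeddingModelConfig(task="classification")
model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=trainer_config,
)
# model.fit(train=train_df, validation=val_df)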
2 changes: 1 addition & 1 deletion src/pytorch_tabular/categorical_encoders.py
@@ -68,7 +68,7 @@ def transform(self, X):
             X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

             if self.handle_unseen == "impute":
-                X_encoded[col].fillna(self._imputed, inplace=True)
+                X_encoded[col] = X_encoded[col].fillna(self._imputed)
             elif self.handle_unseen == "error":
                 if np.unique(X_encoded[col]).shape[0] > mapping.shape[0]:
                     raise ValueError(f"Unseen categories found in `{col}` column.")
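Context for the change above: in-place `fillna` on a column selection is deprecated in recent pandas and will stop working under copy-on-write; assigning the result back is the forward-compatible pattern. A standalone sketch (not from this commit):

import pandas as pd

df = pd.DataFrame({"col": ["a", None, "b"]})
df["col"] = df["col"].fillna("imputed_value")  # assignment form; no FutureWarning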
18 changes: 12 additions & 6 deletions src/pytorch_tabular/config/config.py
@@ -96,6 +96,8 @@ class DataConfig:
         handle_missing_values (bool): Whether to handle missing values in categorical columns as
             unknown

+        pickle_protocol (int): pickle protocol version passed to `torch.save` for dataset caching to disk
+
         dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
             https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
@@ -179,6 +181,11 @@ class DataConfig:
         metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
     )

+    pickle_protocol: int = field(
+        default=2,
+        metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"},
+    )
+
     dataloader_kwargs: Dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
@@ -351,8 +358,8 @@ class TrainerConfig:
         progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`.

-        precision (int): Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`..
-            Choices are: [`32`,`16`,`64`].
+        precision (str): Precision of the model. Defaults to `32`. See
+            https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision

         seed (int): Seed for random number generators. Defaults to 42
@@ -536,11 +543,10 @@ class TrainerConfig:
         default="rich",
         metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
     )
-    precision: int = field(
-        default=32,
+    precision: str = field(
+        default="32",
         metadata={
-            "help": "Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.",
-            "choices": [32, 16, 64],
+            "help": "Precision of the model. Defaults to `32`.",
         },
     )
     seed: int = field(
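With precision now a plain string, any value Lightning's `Trainer(precision=...)` accepts can be passed through; the exact set is version-dependent (see the linked Lightning docs). An assumed example:

from pytorch_tabular.config import TrainerConfig

# "32" stays the default; strings such as "16-mixed" or "bf16-mixed" are
# typical mixed-precision values in Lightning 2.x (version-dependent).
trainer_config = TrainerConfig(precision="16-mixed")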
16 changes: 11 additions & 5 deletions src/pytorch_tabular/feature_extractor.py
@@ -79,15 +79,21 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
                 if k in ret_value.keys():
                     logits_predictions[k].append(ret_value[k].detach().cpu())

+        logits_dfs = []
         for k, v in logits_predictions.items():
             v = torch.cat(v, dim=0).numpy()
             if v.ndim == 1:
                 v = v.reshape(-1, 1)
-            for i in range(v.shape[-1]):
-                if v.shape[-1] > 1:
-                    X_encoded[f"{k}_{i}"] = v[:, i]
-                else:
-                    X_encoded[f"{k}"] = v[:, i]
+            if v.shape[-1] > 1:
+                temp_df = pd.DataFrame({f"{k}_{i}": v[:, i] for i in range(v.shape[-1])})
+            else:
+                temp_df = pd.DataFrame({f"{k}": v[:, 0]})
+
+            # Append the temp DataFrame to the list
+            logits_dfs.append(temp_df)
+
+        preds = pd.concat(logits_dfs, axis=1)
+        X_encoded = pd.concat([X_encoded, preds], axis=1)

         if self.drop_original:
             X_encoded.drop(columns=orig_features, inplace=True)
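Background on the `PerformanceWarning`: assigning many new columns one at a time (`frame.insert()` under the hood) fragments the DataFrame; collecting the columns and concatenating once avoids it. A toy sketch (not from this commit):

import numpy as np
import pandas as pd

base = pd.DataFrame({"x": range(1_000)})
new_cols = {f"feat_{i}": np.random.rand(1_000) for i in range(200)}

# One concat instead of 200 inserts; index alignment keeps rows matched.
base = pd.concat([base, pd.DataFrame(new_cols, index=base.index)], axis=1)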
26 changes: 19 additions & 7 deletions src/pytorch_tabular/models/base_model.py
@@ -244,13 +244,14 @@ def _setup_metrics(self):
         else:
             self.metrics = self.custom_metrics

-    def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
+    def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str, sync_dist: bool = False) -> torch.Tensor:
         """Calculates the loss for the model.

         Args:
             output (Dict): The output dictionary from the model
             y (torch.Tensor): The target tensor
             tag (str): The tag to use for logging
+            sync_dist (bool): enable distributed sync of logs

         Returns:
             torch.Tensor: The loss value
@@ -270,6 +271,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
         if self.hparams.task == "regression":
             computed_loss = reg_loss
@@ -284,6 +286,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
         else:
             # TODO loss fails with batch size of 1?
@@ -301,6 +304,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
                 start_index = end_index
         self.log(
@@ -311,10 +315,13 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
             # on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return computed_loss

-    def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
+    def calculate_metrics(
+        self, y: torch.Tensor, y_hat: torch.Tensor, tag: str, sync_dist: bool = False
+    ) -> List[torch.Tensor]:
         """Calculates the metrics for the model.

         Args:
@@ -324,6 +331,8 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
             tag (str): The tag to use for logging
+            sync_dist (bool): enable distributed sync of logs

         Returns:
             List[torch.Tensor]: The list of metric values
@@ -356,6 +365,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
                 _metrics.append(_metric)
             avg_metric = torch.stack(_metrics, dim=0).sum()
@@ -379,6 +389,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
                 _metrics.append(_metric)
                 start_index = end_index
@@ -391,6 +402,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
             on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return metrics

@@ -523,19 +535,19 @@ def validation_step(self, batch, batch_idx):
             # fetched from the batch
             y = batch["target"] if y is None else y
             y_hat = output["logits"]
-            self.calculate_loss(output, y, tag="valid")
-            self.calculate_metrics(y, y_hat, tag="valid")
+            self.calculate_loss(output, y, tag="valid", sync_dist=True)
+            self.calculate_metrics(y, y_hat, tag="valid", sync_dist=True)
             return y_hat, y

     def test_step(self, batch, batch_idx):
         with torch.no_grad():
             output, y = self.forward_pass(batch)
-            # y is not None for SSL task.Rest of the tasks target is
+            # y is not None for SSL task. Rest of the tasks target is
             # fetched from the batch
             y = batch["target"] if y is None else y
             y_hat = output["logits"]
-            self.calculate_loss(output, y, tag="test")
-            self.calculate_metrics(y, y_hat, tag="test")
+            self.calculate_loss(output, y, tag="test", sync_dist=True)
+            self.calculate_metrics(y, y_hat, tag="test", sync_dist=True)
             return y_hat, y

     def configure_optimizers(self):
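How `sync_dist` behaves, in a minimal Lightning sketch (assumed illustration, not code from this repository): with DDP each rank computes its own value, and `sync_dist=True` reduces (by default, averages) the logged value across ranks, so epoch-level logs and checkpoint monitors agree on every process.

import pytorch_lightning as pl
import torch

class LitSketch(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # stand-in for a real validation loss
        # Without sync_dist, each rank logs only its local value.
        self.log("valid_loss", loss, on_epoch=True, sync_dist=True)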
12 changes: 6 additions & 6 deletions src/pytorch_tabular/ssl_models/base_model.py
@@ -136,11 +136,11 @@ def _setup_metrics(self):
         pass

     @abstractmethod
-    def calculate_loss(self, output, tag):
+    def calculate_loss(self, output, tag, sync_dist):
         pass

     @abstractmethod
-    def calculate_metrics(self, output, tag):
+    def calculate_metrics(self, output, tag, sync_dist):
         pass

     @abstractmethod
@@ -167,15 +167,15 @@ def training_step(self, batch, batch_idx):
     def validation_step(self, batch, batch_idx):
         with torch.no_grad():
             output = self.forward(batch)
-            self.calculate_loss(output, tag="valid")
-            self.calculate_metrics(output, tag="valid")
+            self.calculate_loss(output, tag="valid", sync_dist=True)
+            self.calculate_metrics(output, tag="valid", sync_dist=True)
             return output

     def test_step(self, batch, batch_idx):
         with torch.no_grad():
             output = self.forward(batch)
-            self.calculate_loss(output, tag="test")
-            self.calculate_metrics(output, tag="test")
+            self.calculate_loss(output, tag="test", sync_dist=True)
+            self.calculate_metrics(output, tag="test", sync_dist=True)
             return output

     def on_validation_epoch_end(self) -> None:
6 changes: 4 additions & 2 deletions src/pytorch_tabular/ssl_models/dae/dae.py
@@ -200,7 +200,7 @@ def forward(self, x: Dict):
         else:
             return z.features

-    def calculate_loss(self, output, tag):
+    def calculate_loss(self, output, tag, sync_dist=False):
         total_loss = 0
         for type_, out in output.items():
             if type_ == "categorical":
@@ -220,6 +220,7 @@ def calculate_loss(self, output, tag):
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
             total_loss += loss
         self.log(
@@ -230,10 +231,11 @@ def calculate_loss(self, output, tag):
             # on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return total_loss

-    def calculate_metrics(self, output, tag):
+    def calculate_metrics(self, output, tag, sync_dist=False):
         pass

     def featurize(self, x: Dict):
13 changes: 9 additions & 4 deletions src/pytorch_tabular/tabular_datamodule.py
@@ -61,6 +61,7 @@ def __init__(
         self.task = task
         self.n = data.shape[0]
         self.target = target
+        self.index = data.index
         if target:
             self.y = data[target].astype(np.float32).values
             if isinstance(target, str):
@@ -87,11 +88,12 @@ def data(self):
             data = pd.DataFrame(
                 np.concatenate([self.categorical_X, self.continuous_X], axis=1),
                 columns=self.categorical_cols + self.continuous_cols,
+                index=self.index,
             )
         elif self.continuous_cols:
-            data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols)
+            data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols, index=self.index)
         elif self.categorical_cols:
-            data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols)
+            data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols, index=self.index)
         else:
             data = pd.DataFrame()
         for i, t in enumerate(self.target):
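Why store the index: `TabularDataset` converts the input frame to NumPy arrays, which drops any non-default index; passing `index=self.index` back into `pd.DataFrame` lets `TabularDataset.data` line up with the caller's original rows. A toy illustration (not from this commit):

import pandas as pd

orig = pd.DataFrame({"a": [1.0, 2.0]}, index=[101, 205])
values = orig["a"].to_numpy()  # plain array: index information is gone
restored = pd.DataFrame(values, columns=["a"], index=orig.index)
assert restored.index.equals(orig.index)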
@@ -474,6 +476,7 @@ def _cache_dataset(self):
             target=self.target,
         )
         self.train = None
+
         validation_dataset = TabularDataset(
             task=self.config.task,
             data=self.validation,
@@ -484,8 +487,10 @@ def _cache_dataset(self):
         self.validation = None

         if self.cache_mode is self.CACHE_MODES.DISK:
-            torch.save(train_dataset, self.cache_dir / "train_dataset")
-            torch.save(validation_dataset, self.cache_dir / "validation_dataset")
+            torch.save(train_dataset, self.cache_dir / "train_dataset", pickle_protocol=self.config.pickle_protocol)
+            torch.save(
+                validation_dataset, self.cache_dir / "validation_dataset", pickle_protocol=self.config.pickle_protocol
+            )
         elif self.cache_mode is self.CACHE_MODES.MEMORY:
             self.train_dataset = train_dataset
             self.validation_dataset = validation_dataset
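`pickle_protocol` is forwarded verbatim to `torch.save`. Protocol 4 and above can serialize objects larger than 4 GB, which matters for big cached datasets, while the default of 2 preserves the old on-disk behavior. A standalone sketch (file name is illustrative):

import torch

torch.save({"weights": torch.zeros(10)}, "cached_dataset.pt", pickle_protocol=4)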
4 changes: 4 additions & 0 deletions src/pytorch_tabular/tabular_model.py
@@ -31,6 +31,7 @@
 )
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities.model_summary import summarize
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from rich import print as rich_print
 from rich.pretty import pprint
 from sklearn.base import TransformerMixin
@@ -685,6 +686,8 @@ def train(
                     "/n" + "Original Error: " + oom_handler.oom_msg
                 )
         self._is_fitted = True
+        if self.track_experiment and self.config.log_target == "wandb":
+            self.logger.experiment.unwatch(self.model)
         if self.verbose:
             logger.info("Training the model completed")
         if self.config.load_best:
@@ -1522,6 +1525,7 @@ def add_noise(module, input, output):
         )
         return pred_df

+    @rank_zero_only
     def load_best_model(self) -> None:
         """Loads the best model after training is done."""
         if self.trainer.checkpoint_callback is not None:
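`rank_zero_only` makes the decorated function run only on the global rank-0 process (it is a no-op on all other ranks), so in DDP a single process reloads the best checkpoint instead of every rank hitting the file at once. A minimal sketch (assumed illustration):

from pytorch_lightning.utilities.rank_zero import rank_zero_only

@rank_zero_only
def announce() -> None:
    print("only the rank-0 process prints this")

announce()  # executes once per job, regardless of world size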
