Skip to content

Commit

Permalink
simplify: train/val funcs directly return train/val accuracies
Browse files Browse the repository at this point in the history
  • Loading branch information
ImahnShekhzadeh committed May 19, 2024
1 parent 4c5789f commit 3caa220
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions lstm_vision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def train_and_validate(
for epoch in range(num_epochs):
t0 = start_timer(device=rank)

train_loss, num_correct, num_samples = train_one_epoch(
train_loss, train_acc = train_one_epoch(
train_loader=train_loader,
model=model,
optimizer=optimizer,
Expand All @@ -114,9 +114,8 @@ def train_and_validate(
# mean loss per sample over all GPUs; we could alternatively use
# `torch.distributed.reduce()` to sum the losses over all GPUs
train_loss *= world_size / len(train_loader.dataset)
train_acc = num_correct / num_samples

val_loss, num_correct, num_samples = validate_one_epoch(
val_loss, val_acc = validate_one_epoch(
model=model,
val_loader=val_loader,
loss_fn=loss_fn,
Expand All @@ -126,7 +125,6 @@ def train_and_validate(
freq_output__val=freq_output__val,
)
val_loss *= world_size / len(val_loader.dataset)
val_acc = num_correct / num_samples

# update checkpoint dict if val loss has decreased
if val_loss < min_val_loss:
Expand Down Expand Up @@ -227,7 +225,7 @@ def train_one_epoch(
epoch: int,
max_norm: Optional[float] = None,
freq_output__train: Optional[int] = 10,
) -> Tuple[List[float], int, int]:
) -> Tuple[float, float]:
"""
Train model for one epoch.
Expand All @@ -245,8 +243,8 @@ def train_one_epoch(
freq_output__train: Frequency at which to print the training info.
Returns:
Training loss for all batches for the single epoch, number of correct
predictions, and number of samples.
Summed training loss over all batches of the single epoch, and the
training accuracy on rank 0 or CPU.
"""

epoch_loss, num_correct, num_samples = 0, 0, 0 # auxiliary variables
Expand Down Expand Up @@ -295,7 +293,7 @@ def train_one_epoch(
frequency=freq_output__train,
)

return epoch_loss, num_correct, num_samples
return epoch_loss, num_correct / num_samples


@torch.no_grad()
Expand All @@ -307,7 +305,7 @@ def validate_one_epoch(
use_amp: bool,
epoch: int,
freq_output__val: Optional[int] = 10,
) -> Tuple[List[float], int, int]:
) -> Tuple[float, float]:
"""
Validate model for one epoch.
Expand All @@ -321,8 +319,8 @@ def validate_one_epoch(
freq_output__val: Frequency at which to print the validation info.
Returns:
Validation loss for all batches for the single epoch, number of correct
predictions, and number of samples.
Summed validation loss over all batches of the single epoch, and the
validation accuracy on rank 0 or CPU.
"""
epoch_loss, val_num_correct, val_num_samples = 0, 0, 0
model.eval()
Expand Down Expand Up @@ -356,4 +354,4 @@ def validate_one_epoch(
frequency=freq_output__val,
)

return epoch_loss, val_num_correct, val_num_samples
return epoch_loss, val_num_correct / val_num_samples

0 comments on commit 3caa220

Please sign in to comment.