Skip to content

Commit

Permalink
add logging of energy consumption of entire training, add package for…
Browse files Browse the repository at this point in the history
… this to `pyproject.toml`
  • Loading branch information
ImahnShekhzadeh committed May 16, 2024
1 parent 3651dc5 commit 215c8f6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
20 changes: 18 additions & 2 deletions lstm_vision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from torch.cuda.amp import GradScaler
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from zeus.monitor import ZeusMonitor

from utils import (
end_timer_and_print,
log_training_stats,
print__batch_info,
save_checkpoint,
start_timer,
Expand Down Expand Up @@ -77,12 +78,22 @@ def train_and_validate(
reduction="mean", label_smoothing=label_smoothing
)

# auxiliary variables
start_time = start_timer(device=rank)
train_losses, val_losses, train_accs, val_accs = [], [], [], []
min_val_loss = float("inf")

# AMP
scaler = GradScaler(enabled=use_amp)

# measure energy consumption (rank 0 already measures energy consumption
# of all GPUs)
if rank in [0, torch.device("cuda:0"), torch.device("cuda")]:
monitor = ZeusMonitor(
gpu_indices=[i for i in range(torch.cuda.device_count())]
)
monitor.begin_window("training")

for epoch in range(num_epochs):
t0 = start_timer(device=rank)

Expand Down Expand Up @@ -172,12 +183,17 @@ def train_and_validate(
"%\n"
)

# stop energy consumption measurement
if rank in [0, torch.device("cuda:0"), torch.device("cuda")]:
measurement = monitor.end_window("training")

# number of iterations per device
num_iters = len(train_loader) * num_epochs

if rank in [0, torch.device("cpu")]:
end_timer_and_print(
log_training_stats(
start_time=start_time,
energy_consump=measurement.total_energy,
device=rank,
local_msg=(
f"Training {num_epochs} epochs ({num_iters} iterations)"
Expand Down
10 changes: 8 additions & 2 deletions lstm_vision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,19 @@ def start_timer(device: torch.device | int) -> float:
return perf_counter()


def end_timer_and_print(
start_time: float, device: torch.device | int, local_msg: str = ""
def log_training_stats(
start_time: float,
energy_consump: float,
device: torch.device | int,
local_msg: str = "",
) -> float:
"""
End the timer and print the time it took to execute the code as well as the
maximum memory used by tensors.
Args:
start_time: Time at which the training started.
energy_consump: Energy consumption of the entire training in Joules.
device: Device on which the code was executed. Can also be an int
representing the GPU ID.
local_msg: Local message to print.
Expand All @@ -412,6 +416,8 @@ def end_timer_and_print(
msg = f"{local_msg}\n\tTotal execution time = {time_diff:.3f} [sec]"
if device.type == "cuda":
msg += (
f"\n\tEnergy consumption of entire training = "
f"{energy_consump / 1e3:.3f} [kJ]"
f"\n\tMax memory used by tensors = "
f"{torch.cuda.max_memory_allocated(device=device) / 1024**2:.3f} "
"[MB]"
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"python-dotenv~=1.0",
"matplotlib~=3.2",
"pre-commit",
"zeus-ml~=0.9"
]

[tool.isort]
Expand Down

0 comments on commit 215c8f6

Please sign in to comment.