-
Notifications
You must be signed in to change notification settings - Fork 252
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add model num params display, gpu memory metrics (#56)
This PR is the start of adding perf related metrics. 1 - This PR adds function for logging the total num of unique model params, with option for only counting trainable params as well. (for future peft/qlora type work). 2 - logs it with comma formatted logging and model name ala: <img width="716" alt="Screenshot 2024-02-12 at 4 12 22 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/8eb48870-ab1e-4b70-9159-92864ff6c0e5"> this helps de-mistify for example the size of our debug model as well: <img width="716" alt="Screenshot 2024-02-12 at 4 10 17 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/77475306-54bc-48a6-bf28-9c9a542577fd"> **additional updates** - added in gpu mem tracking. We want to show the user peak memory stats, as well as monitor and alert for any cudacachealloc retries which are a perf hindrance. Thus, added class GPUMemoryMonitor: usage: 1 - instantiate <img width="1329" alt="Screenshot 2024-02-13 at 9 32 11 AM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/95610386-6fde-47bb-bbdc-bb7c399c5895"> 2 - start of training = start_monitoring() 3 - end of training = stop_monitoring() 4 - show results = get_peak_stats_str() and rank0_log it. <img width="1074" alt="Screenshot 2024-02-13 at 9 12 45 AM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/b6c7c854-7d83-436a-bea9-a67109422381"> [ghstack-poisoned]
- Loading branch information
Showing
3 changed files
with
202 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. | ||
|
||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved | ||
|
||
from collections import namedtuple | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
_gb_in_bytes = 1024 * 1024 * 1024 | ||
_mb_in_bytes = 1024 * 1024 | ||
|
||
|
||
def format_to_gb(item, precision=4): | ||
"""quick function to format numbers to gigabyte and round to (default) 4 digit precision""" | ||
metric_num = item / _gb_in_bytes | ||
metric_num = round(metric_num, ndigits=precision) | ||
return metric_num | ||
|
||
|
||
def convert_to_gpu_pct(value, total_gpu_memory): | ||
return round(100 * (value / total_gpu_memory), 2) | ||
|
||
|
||
# named tuple for passing memory stats (as % of device capacity) for Tensorboard logging | ||
GPUMemStats = namedtuple( | ||
"GPUMemStats", | ||
[ | ||
"allocated_curr", | ||
"allocated_peak", | ||
"reserved_curr", | ||
"reserved_peak", | ||
"active_curr", | ||
"active_peak", | ||
"num_retries", | ||
], | ||
) | ||
|
||
|
||
class GPUMemoryMonitor: | ||
""" | ||
Class to monitor GPU memory usage | ||
""" | ||
|
||
def __init__(self, device: str = "cuda:0"): | ||
self.device = torch.device(device) # device object | ||
self.device_name = torch.cuda.get_device_name(self.device) | ||
self.device_index = torch.cuda.current_device() | ||
self.device_capacity = torch.cuda.get_device_properties( | ||
self.device | ||
).total_memory | ||
self.device_capacity_gb = format_to_gb(self.device_capacity) | ||
self.num_retries = 0 | ||
self.num_ooms = 0 | ||
self.peak_active_memory = 0 | ||
self.peak_allocated_memory = 0 | ||
self.peak_reserved_memory = 0 | ||
self.curr_reserved_memory = 0 | ||
|
||
self.device_reserved_memory_usage = 0 | ||
self.device_reserved_memory_gb = 0 | ||
self.device_reserved_memory_pct = 0 | ||
|
||
self.device_active_memory_usage = 0 | ||
self.device_active_memory_gb = 0 | ||
self.device_active_memory_pct = 0 | ||
|
||
# current stats | ||
self.device_alloc_memory_usage = torch.cuda.memory_allocated(self.device) | ||
self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage) | ||
self.device_alloc_memory_pct = convert_to_gpu_pct( | ||
self.device_alloc_memory_usage, self.device_capacity | ||
) | ||
|
||
# reset stats, clear cache | ||
torch.cuda.reset_peak_memory_stats() | ||
torch.cuda.empty_cache() | ||
|
||
def get_pct_memory(self, memory_num): | ||
pct_memory = memory_num / self.device_capacity | ||
pct_memory = round(100 * (pct_memory), 2) | ||
return pct_memory | ||
|
||
def get_gb_memory(self, memory_num): | ||
gb_memory = memory_num / _gb_in_bytes | ||
gb_memory = round(gb_memory, 2) | ||
return gb_memory | ||
|
||
def get_current_stats(self, return_data: bool = False): | ||
""" | ||
get the CudaCachingAllocator stats for the current device | ||
return_data: bool, if True, return the data as a named tuple | ||
""" | ||
curr_mem = torch.cuda.memory_stats(self.device) | ||
|
||
self.device_alloc_memory_usage = curr_mem["allocated_bytes.all.current"] | ||
self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage) | ||
self.device_alloc_memory_pct = convert_to_gpu_pct( | ||
self.device_alloc_memory_usage, self.device_capacity | ||
) | ||
|
||
self.device_reserved_memory_usage = curr_mem["reserved_bytes.all.current"] | ||
self.device_reserved_memory_gb = format_to_gb(self.device_reserved_memory_usage) | ||
self.device_reserved_memory_pct = convert_to_gpu_pct( | ||
self.device_reserved_memory_usage, self.device_capacity | ||
) | ||
|
||
self.device_active_memory_usage = curr_mem["active_bytes.all.current"] | ||
self.device_active_memory_gb = format_to_gb(self.device_active_memory_usage) | ||
self.device_active_memory_pct = convert_to_gpu_pct( | ||
self.device_active_memory_usage, self.device_capacity | ||
) | ||
|
||
display_str = "" | ||
display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%," | ||
display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n" | ||
|
||
self.get_peak_stats(curr_mem) | ||
|
||
peak_active_pct = self.get_pct_memory(self.peak_active_memory) | ||
peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory) | ||
peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory) | ||
display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n" | ||
|
||
display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}" | ||
if self.num_retries > 0: | ||
display_str += f"\nWARNING: {self.num_retries} retries -- recommend lowering batch size for max performance\n" | ||
|
||
if not return_data: | ||
return display_str | ||
|
||
# return named tuple | ||
curr_mem_stats = GPUMemStats( | ||
self.device_alloc_memory_pct, | ||
peak_active_pct, | ||
self.device_reserved_memory_pct, | ||
peak_reserved_pct, | ||
self.device_active_memory_pct, | ||
peak_active_pct, | ||
self.num_retries, | ||
) | ||
return curr_mem_stats | ||
|
||
def start_monitoring(self): | ||
"""reset all monitoring stats""" | ||
self.reset_peak_stats() | ||
|
||
def get_peak_stats(self, cuda_info=None): | ||
"""capture current peak memory stats""" | ||
if not cuda_info: | ||
cuda_info = torch.cuda.memory_stats() | ||
|
||
self.peak_active_memory = cuda_info.get("active_bytes.all.peak", 0) | ||
self.peak_allocated_memory = cuda_info.get("allocated_bytes.all.peak", 0) | ||
self.peak_reserved_memory = cuda_info.get("reserved_bytes.all.peak", 0) | ||
|
||
self.num_retries = cuda_info.get("num_alloc_retries", 0) | ||
self.num_ooms = cuda_info.get("num_ooms", 0) | ||
|
||
def reset_peak_stats(self): | ||
"""reset peak memory stats""" | ||
torch.cuda.reset_peak_memory_stats() | ||
torch.cuda.empty_cache() | ||
self.num_retries = 0 | ||
self.num_ooms = 0 | ||
self.active_peak_memory_utilization_str = "" | ||
self.peak_memory_utilization_str = "" | ||
self.peak_reserved_memory_utilization_str = "" | ||
|
||
def __str__(self): | ||
_ = self.get_current_stats() | ||
display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gb} GB capacity, " | ||
display_str += f"{self.device_alloc_memory_gb} GB in-use, {self.device_alloc_memory_pct}% in-use" | ||
return f"{display_str}" | ||
|
||
|
||
def get_num_params(model: nn.Module, only_trainable: bool = False) -> int: | ||
""" | ||
Get the total model params | ||
Args : only_trainable: whether to only count trainable params | ||
""" | ||
param_list = list(model.parameters()) | ||
if only_trainable: | ||
param_list = [p for p in param_list if p.requires_grad] | ||
unique_params = {p.data_ptr(): p for p in param_list}.values() | ||
return sum(p.numel() for p in unique_params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters