Fix module to device in AutoUnit #398

Closed · wants to merge 2 commits
4 changes: 2 additions & 2 deletions torchtnt/framework/auto_unit.py
@@ -259,7 +259,7 @@ def __init__(
# remove ddp comm hook variables from params dict
del params_dict["comm_state"]
del params_dict["comm_hook"]
- module = module.to(device)
+ module = module.to(self.device)
module = DDP(module, device_ids=device_ids, **params_dict)
if torchdynamo_params:
# TODO: Add support for dynamo and DDP
@@ -295,7 +295,7 @@ def __init__(
**asdict(strategy),
)
else:
- module = module.to(device)
+ module = module.to(self.device)

self.module: torch.nn.Module = module

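For readers skimming the diff: a hedged sketch of why the one-token change matters, assuming (this is inferred from the fix, not stated in the diff) that AutoUnit's constructor argument `device` may be `None` while `self.device` always holds the resolved device:

```python
# Illustrative sketch, not TorchTNT's actual code.
from typing import Optional

import torch


class Unit:
    def __init__(
        self, module: torch.nn.Module, device: Optional[torch.device] = None
    ) -> None:
        # Resolve the device once and store it on the instance.
        self.device = device if device is not None else torch.device("cpu")
        # Buggy: Module.to(None) is effectively a no-op, so with device=None
        # the module was never moved.
        # module = module.to(device)
        # Fixed: always move to the resolved attribute.
        module = module.to(self.device)
        self.module = module
```

Both hunks above apply the same substitution, once on the DDP path and once on the plain path.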
2 changes: 2 additions & 0 deletions torchtnt/utils/__init__.py
@@ -6,6 +6,7 @@

from .device import (
copy_data_to_device,
+ copy_list_tensors_to_device,
CPUStats,
get_device_from_env,
get_nvidia_smi_gpu_stats,
@@ -56,6 +57,7 @@

__all__ = [
"copy_data_to_device",
"copy_list_tensors_to_device",
"CPUStats",
"get_device_from_env",
"get_nvidia_smi_gpu_stats",
31 changes: 30 additions & 1 deletion torchtnt/utils/device.py
@@ -10,7 +10,7 @@
import subprocess
from collections import defaultdict
from dataclasses import fields, is_dataclass
- from typing import Any, Dict, Mapping, TypeVar
+ from typing import Any, Dict, List, Mapping, TypeVar

import torch
from torchtnt.utils.version import is_torch_version_geq_1_12
@@ -123,6 +123,35 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any
return data


+ def copy_list_tensors_to_device(
+     datae: List[torch.Tensor], device: torch.device, *args: Any, **kwargs: Any
+ ) -> List[torch.Tensor]:
+     """Copies a list of Tensors to a torch.device. This has better performance
+     than copy_data_to_device because the tensors are flattened into a single
+     buffer, so only one transfer is issued.
+
+     Args:
+         datae: the list of tensors to copy to device
+         device: the device to which the tensors should be copied
+         args: positional arguments that will be passed to the `to` call
+         kwargs: keyword arguments that will be passed to the `to` call
+
+     Returns:
+         The list of tensors on the target device
+     """
+
+     # Flatten every tensor, concatenate into one buffer, copy that buffer in a
+     # single `to` call, then split it back and restore the original shapes.
+     return [
+         d.view_as(data)
+         for (d, data) in zip(
+             torch.split_with_sizes(
+                 torch.cat([data.reshape(-1) for data in datae]).to(
+                     device, *args, **kwargs
+                 ),
+                 [data.numel() for data in datae],
+             ),
+             datae,
+         )
+     ]
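For context, a minimal usage sketch of the new helper (the shapes, the `non_blocking` flag, and the CUDA fallback are illustrative assumptions, not part of this PR):

```python
import torch

from torchtnt.utils import copy_list_tensors_to_device

# Hypothetical batch of differently shaped tensors, moved in a single transfer.
tensors = [torch.randn(4, 8), torch.randn(16), torch.randn(2, 3, 5)]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

moved = copy_list_tensors_to_device(tensors, device, non_blocking=True)
assert all(m.device.type == device.type for m in moved)
assert all(m.shape == t.shape for m, t in zip(moved, tensors))
```

One caveat worth noting: since everything goes through a single `torch.cat`, the tensors in the list should share a dtype for the round trip to be a faithful copy.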


def record_data_in_stream(data: T, stream: torch.cuda.streams.Stream) -> None:
"""
As mentioned in