diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index 2b809db847..66e7c9e131 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -259,7 +259,7 @@ def __init__( # remove ddp comm hook variables from params dict del params_dict["comm_state"] del params_dict["comm_hook"] - module = module.to(device) + module = module.to(self.device) module = DDP(module, device_ids=device_ids, **params_dict) if torchdynamo_params: # TODO: Add support for dynamo and DDP @@ -295,7 +295,7 @@ def __init__( **asdict(strategy), ) else: - module = module.to(device) + module = module.to(self.device) self.module: torch.nn.Module = module