
Commit 8a15419

snisarg authored and facebook-github-bot committed
Remove vocab from cuda (facebookresearch#955)
Summary:
Pull Request resolved: facebookresearch#955

Some users can't train models with extremely large embedding tables because we try to allocate space for them on the GPU. With this diff, we add a flag that users can set explicitly during training to keep the embedding layer on the CPU even when the model is trained on GPUs. This is not the default because users need to know there is a cost associated with moving the tensors on and off the GPU.

Note that this only applies during training. Also note that this does not work in a multi-GPU environment because of the way the weights are synced via NCCL.

Differential Revision: D17114398

fbshipit-source-id: 840f37f77c70089137f2cf23a262dc503e5e2080
1 parent c7dd752 commit 8a15419

File tree

3 files changed: +25 −2 lines


Diff for: pytext/config/field_config.py (+1)

@@ -40,6 +40,7 @@ class WordFeatConfig(ModuleConfig):
     min_freq: int = 1
     mlp_layer_dims: Optional[List[int]] = []
     padding_idx: Optional[int] = None
+    cpu_only: bool = False


 class DictFeatConfig(ModuleConfig):
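
For context, a minimal sketch of how a user might opt in to the new flag. Only cpu_only comes from this diff; the other field values are illustrative assumptions:

    from pytext.config.field_config import WordFeatConfig

    # Hypothetical usage: keep a very large embedding table on the CPU
    # while the rest of the model trains on a single GPU.
    word_feat = WordFeatConfig(
        embed_dim=300,   # illustrative value, not part of this diff
        cpu_only=True,   # new flag added here; defaults to False
    )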

Diff for: pytext/models/embeddings/word_embedding.py (+11 −2)

@@ -1,14 +1,14 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

-import collections
 from typing import List, Optional

 import torch
 from pytext.config.field_config import WordFeatConfig
 from pytext.data.tensorizers import Tensorizer
 from pytext.fields import FieldMeta
 from pytext.utils.embeddings import PretrainedEmbedding
+from pytext.utils.torch import CPUOnlyParameter
 from tensorboardX import SummaryWriter
 from torch import nn

@@ -96,6 +96,7 @@ def from_config(
             mlp_layer_dims=config.mlp_layer_dims,
             padding_idx=config.padding_idx,
             vocab=vocab,
+            cpu_only=config.cpu_only,
         )

     def __init__(
@@ -108,6 +109,7 @@ def __init__(
         mlp_layer_dims: List[int] = (),
         padding_idx: Optional[int] = None,
         vocab: Optional[List[str]] = None,
+        cpu_only: bool = False,
     ) -> None:
         output_embedding_dim = mlp_layer_dims[-1] if mlp_layer_dims else embedding_dim
         EmbeddingBase.__init__(self, embedding_dim=output_embedding_dim)
@@ -119,6 +121,8 @@ def __init__(
             _weight=embeddings_weight,
             padding_idx=padding_idx,
         )
+        if cpu_only:
+            self.word_embedding.weight = CPUOnlyParameter(self.word_embedding.weight)
         if embeddings_weight is None and init_range:
             self.word_embedding.weight.data.uniform_(init_range[0], init_range[1])
         # Initialize unk embedding with zeros
@@ -142,7 +146,12 @@ def __getattr__(self, name):
         return super().__getattr__(name)

     def forward(self, input):
-        return self.mlp(self.word_embedding(input))
+        input_device = input.device
+        embedding_device = self.word_embedding.weight.device
+        if input_device != embedding_device:
+            input = input.to(embedding_device)
+        # We only want to do the embedding lookup on CPU
+        return self.mlp(self.word_embedding(input).to(input_device))

     def freeze(self):
         for param in self.word_embedding.parameters():
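To illustrate the device round-trip the new forward performs, here is a standalone sketch in plain PyTorch; the tensor sizes and the presence of a GPU are assumptions, not part of the diff:

    import torch
    from torch import nn

    # The embedding table stays on the CPU; inputs may arrive on a GPU.
    embedding = nn.Embedding(1_000_000, 300)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokens = torch.randint(0, 1_000_000, (8, 16), device=device)
    indices = tokens.to(embedding.weight.device)  # move indices to the CPU
    embedded = embedding(indices).to(device)      # CPU lookup, result back to the GPU

This mirrors the committed forward: indices move to the weight's device for the lookup, so only the small per-batch output travels back to the GPU, never the full table.
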
Diff for: pytext/utils/torch.py (+13)

@@ -5,6 +5,7 @@
 from typing import Dict, List, Optional, Tuple

 import torch
+from pytext.utils import cuda


 # ===== the following section should be replaced once JIT provide native support
@@ -500,3 +501,15 @@ def package_for_inference(self):
         self.do_normalization = torch.jit.Attribute(self.do_normalization, bool)
         self.feature_avgs = torch.jit.Attribute(self.feature_avgs, List[float])
         self.feature_stddevs = torch.jit.Attribute(self.feature_stddevs, List[float])
+
+
+class CPUOnlyParameter(torch.nn.Parameter):
+    def __init__(self, *args, **kwargs):
+        assert (
+            cuda.DISTRIBUTED_WORLD_SIZE <= 1
+        ), "Multiple GPUs not supported for cpu_only embeddings"
+        super().__init__(*args, **kwargs)
+
+    def cuda(self, device=None):
+        # We do nothing because this Parameter should only be on the CPU
+        return self

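A short sketch of what the override buys, assuming a single-process run (the assertion requires cuda.DISTRIBUTED_WORLD_SIZE <= 1): cuda() returns the parameter unchanged, so the weight never leaves the CPU:

    import torch
    from pytext.utils.torch import CPUOnlyParameter

    param = CPUOnlyParameter(torch.randn(1_000_000, 300))

    # cuda() is overridden to return self, so the weight stays on the CPU.
    assert param.cuda().device.type == "cpu"

This guards explicit cuda() calls on the parameter itself; as the summary notes, multi-GPU training remains unsupported because of the way weights are synced via NCCL.
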