From 189fdf4c9b3bd66cc2c2d91e4ebafe2fb0bb1f00 Mon Sep 17 00:00:00 2001
From: Nisarg Shah
Date: Mon, 23 Sep 2019 19:36:42 -0700
Subject: [PATCH] Remove vocab from cuda (#955)

Summary:
Pull Request resolved: https://github.com/facebookresearch/pytext/pull/955

Some users can't train models with extremely large embeddings because we try to allocate space for the whole embedding table on the GPU. This diff adds a flag that users can set explicitly to keep the embedding layer on the CPU during training, even when the rest of the model is trained on GPUs. The flag is off by default because users need to be aware that there is a cost associated with moving tensors on and off the GPU.

Note that this only applies during training, and that it does not work in a multi-GPU environment because of the way the weights are synced via NCCL.
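The pattern this enables looks roughly like the following (a minimal
sketch with made-up sizes, not the actual PyText wiring): the lookup
runs on the CPU, and only the looked-up rows are shipped to the GPU.

    import torch
    from torch import nn

    # Hypothetical sizes: a table this large may not fit in GPU memory.
    embedding = nn.Embedding(5_000_000, 300)  # stays on the CPU
    mlp = nn.Linear(300, 128).to("cuda")      # rest of the model is on the GPU

    tokens = torch.randint(0, 5_000_000, (32, 20), device="cuda")
    # Look up on the CPU, then move only the (32 x 20 x 300) result across.
    vectors = embedding(tokens.to("cpu")).to("cuda")
    output = mlp(vectors)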
Differential Revision: D17114398

fbshipit-source-id: e28b2981fbcbb248a6a704fd3c6e325fd45490e9
---
 pytext/config/field_config.py              |  1 +
 pytext/models/embeddings/word_embedding.py | 13 +++++++++++--
 pytext/utils/torch.py                      | 13 +++++++++++++
 3 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/pytext/config/field_config.py b/pytext/config/field_config.py
index d0c8f93ce..8e8c7ab76 100644
--- a/pytext/config/field_config.py
+++ b/pytext/config/field_config.py
@@ -40,6 +40,7 @@ class WordFeatConfig(ModuleConfig):
     min_freq: int = 1
     mlp_layer_dims: Optional[List[int]] = []
     padding_idx: Optional[int] = None
+    cpu_only: bool = False
 
 
 class DictFeatConfig(ModuleConfig):
diff --git a/pytext/models/embeddings/word_embedding.py b/pytext/models/embeddings/word_embedding.py
index 196412b3d..92ebb895b 100644
--- a/pytext/models/embeddings/word_embedding.py
+++ b/pytext/models/embeddings/word_embedding.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
-import collections
 from typing import List, Optional
 
 import torch
@@ -9,6 +8,7 @@
 from pytext.data.tensorizers import Tensorizer
 from pytext.fields import FieldMeta
 from pytext.utils.embeddings import PretrainedEmbedding
+from pytext.utils.torch import CPUOnlyParameter
 from tensorboardX import SummaryWriter
 from torch import nn
 
@@ -96,6 +96,7 @@ def from_config(
             mlp_layer_dims=config.mlp_layer_dims,
             padding_idx=config.padding_idx,
             vocab=vocab,
+            cpu_only=config.cpu_only,
         )
 
     def __init__(
@@ -108,6 +109,7 @@ def __init__(
         mlp_layer_dims: List[int] = (),
         padding_idx: Optional[int] = None,
         vocab: Optional[List[str]] = None,
+        cpu_only: bool = False,
     ) -> None:
         output_embedding_dim = mlp_layer_dims[-1] if mlp_layer_dims else embedding_dim
         EmbeddingBase.__init__(self, embedding_dim=output_embedding_dim)
@@ -119,6 +121,8 @@ def __init__(
             _weight=embeddings_weight,
             padding_idx=padding_idx,
         )
+        if cpu_only:
+            self.word_embedding.weight = CPUOnlyParameter(self.word_embedding.weight)
         if embeddings_weight is None and init_range:
             self.word_embedding.weight.data.uniform_(init_range[0], init_range[1])
         # Initialize unk embedding with zeros
@@ -142,7 +146,12 @@ def __getattr__(self, name):
         return super().__getattr__(name)
 
     def forward(self, input):
-        return self.mlp(self.word_embedding(input))
+        input_device = input.device
+        embedding_device = self.word_embedding.weight.device
+        if input_device != embedding_device:
+            input = input.to(embedding_device)
+        # We only want to do the embedding lookup on CPU
+        return self.mlp(self.word_embedding(input).to(input_device))
 
     def freeze(self):
         for param in self.word_embedding.parameters():
diff --git a/pytext/utils/torch.py b/pytext/utils/torch.py
index 9883a3e9f..64f76ea14 100644
--- a/pytext/utils/torch.py
+++ b/pytext/utils/torch.py
@@ -5,6 +5,7 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
+from pytext.utils import cuda
 
 
 # ===== the following section should be replaced once JIT provide native support
@@ -500,3 +501,15 @@ def package_for_inference(self):
         self.do_normalization = torch.jit.Attribute(self.do_normalization, bool)
         self.feature_avgs = torch.jit.Attribute(self.feature_avgs, List[float])
         self.feature_stddevs = torch.jit.Attribute(self.feature_stddevs, List[float])
+
+
+class CPUOnlyParameter(torch.nn.Parameter):
+    def __init__(self, *args, **kwargs):
+        assert (
+            cuda.DISTRIBUTED_WORLD_SIZE <= 1
+        ), "Multiple GPUs not supported for cpu_only embeddings"
+        super().__init__()
+
+    def cuda(self, device=None):
+        # We do nothing because this Parameter should only be on the CPU
+        return self
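A quick illustration of the new parameter's contract (a hypothetical
snippet, not part of the patch; it assumes a non-distributed setup so
the world-size assertion passes):

    import torch
    from pytext.utils.torch import CPUOnlyParameter

    weight = CPUOnlyParameter(torch.randn(10, 3))
    # cuda() is deliberately a no-op, so code that blindly moves every
    # parameter to the GPU leaves this one behind on the CPU.
    assert weight.cuda().device.type == "cpu"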