diff --git a/pytext/config/field_config.py b/pytext/config/field_config.py
index d0c8f93ce..8e8c7ab76 100644
--- a/pytext/config/field_config.py
+++ b/pytext/config/field_config.py
@@ -40,6 +40,7 @@ class WordFeatConfig(ModuleConfig):
     min_freq: int = 1
     mlp_layer_dims: Optional[List[int]] = []
     padding_idx: Optional[int] = None
+    cpu_only: bool = False


 class DictFeatConfig(ModuleConfig):
diff --git a/pytext/models/embeddings/word_embedding.py b/pytext/models/embeddings/word_embedding.py
index 196412b3d..92ebb895b 100644
--- a/pytext/models/embeddings/word_embedding.py
+++ b/pytext/models/embeddings/word_embedding.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

-import collections
 from typing import List, Optional

 import torch
@@ -9,6 +8,7 @@
 from pytext.data.tensorizers import Tensorizer
 from pytext.fields import FieldMeta
 from pytext.utils.embeddings import PretrainedEmbedding
+from pytext.utils.torch import CPUOnlyParameter
 from tensorboardX import SummaryWriter
 from torch import nn

@@ -96,6 +96,7 @@ def from_config(
             mlp_layer_dims=config.mlp_layer_dims,
             padding_idx=config.padding_idx,
             vocab=vocab,
+            cpu_only=config.cpu_only,
         )

     def __init__(
@@ -108,6 +109,7 @@ def __init__(
         mlp_layer_dims: List[int] = (),
         padding_idx: Optional[int] = None,
         vocab: Optional[List[str]] = None,
+        cpu_only: bool = False,
     ) -> None:
         output_embedding_dim = mlp_layer_dims[-1] if mlp_layer_dims else embedding_dim
         EmbeddingBase.__init__(self, embedding_dim=output_embedding_dim)
@@ -119,6 +121,8 @@ def __init__(
             _weight=embeddings_weight,
             padding_idx=padding_idx,
         )
+        if cpu_only:
+            self.word_embedding.weight = CPUOnlyParameter(self.word_embedding.weight)
         if embeddings_weight is None and init_range:
             self.word_embedding.weight.data.uniform_(init_range[0], init_range[1])
         # Initialize unk embedding with zeros
@@ -142,7 +146,12 @@ def __getattr__(self, name):
         return super().__getattr__(name)

     def forward(self, input):
-        return self.mlp(self.word_embedding(input))
+        input_device = input.device
+        embedding_device = self.word_embedding.weight.device
+        if input_device != embedding_device:
+            input = input.to(embedding_device)
+        # Do the lookup on the embedding's device (CPU when cpu_only is set)
+        return self.mlp(self.word_embedding(input).to(input_device))

     def freeze(self):
         for param in self.word_embedding.parameters():
diff --git a/pytext/utils/torch.py b/pytext/utils/torch.py
index 9883a3e9f..64f76ea14 100644
--- a/pytext/utils/torch.py
+++ b/pytext/utils/torch.py
@@ -5,6 +5,7 @@
 from typing import Dict, List, Optional, Tuple

 import torch
+from pytext.utils import cuda


 # ===== the following section should be replaced once JIT provide native support
@@ -500,3 +501,15 @@ def package_for_inference(self):
         self.do_normalization = torch.jit.Attribute(self.do_normalization, bool)
         self.feature_avgs = torch.jit.Attribute(self.feature_avgs, List[float])
         self.feature_stddevs = torch.jit.Attribute(self.feature_stddevs, List[float])
+
+
+class CPUOnlyParameter(torch.nn.Parameter):
+    def __init__(self, *args, **kwargs):
+        assert (
+            cuda.DISTRIBUTED_WORLD_SIZE <= 1
+        ), "Multiple GPUs not supported for cpu_only embeddings"
+        super().__init__()
+
+    def cuda(self, device=None):
+        # We do nothing because this Parameter should only be on the CPU
+        return self
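
Notes (illustrative sketches for reviewers, not part of the patch):

Enabling the flag from config: a minimal sketch assuming keyword construction
works as for other PyText ConfigBase subclasses; only cpu_only is new in this
patch, and every other field keeps its default.

    from pytext.config.field_config import WordFeatConfig

    # Keep the (potentially very large) word embedding table in host memory
    word_feat = WordFeatConfig(cpu_only=True)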
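
What the new forward() does when the weight lives on the CPU, as a standalone
torch sketch (module names and sizes below are made up for illustration):
indices move to the table's device for the lookup, and only the much smaller
result tensor is shipped back to the input's device for the MLP.

    import torch
    from torch import nn

    table = nn.Embedding(1_000_000, 300)  # large table, deliberately left on CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mlp = nn.Linear(300, 128).to(device)

    tokens = torch.randint(0, 1_000_000, (32, 20), device=device)
    looked_up = table(tokens.to(table.weight.device))  # lookup runs on CPU
    output = mlp(looked_up.to(tokens.device))          # the rest runs on device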
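
And a hypothetical check of what CPUOnlyParameter guarantees: a direct
.cuda() call returns the parameter unchanged, so the weight never leaves
host memory (assumes a single-GPU setup so the assert in __init__ passes).

    import torch
    from pytext.utils.torch import CPUOnlyParameter

    weight = CPUOnlyParameter(torch.randn(10, 4))
    assert weight.cuda() is weight        # .cuda() is a deliberate no-op
    assert weight.device.type == "cpu"    # the data stays on the host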