Token classification inference on CPU #557

Closed
lebionick opened this issue Apr 6, 2020 · 2 comments
lebionick commented Apr 6, 2020

Hello,
I've trained BERT for NER and want to run inference with it on CPU. I used code from here: https://github.com/NVIDIA/NeMo/blob/master/examples/nlp/token_classification/token_classification_infer.py and made this class:

import numpy as np

import nemo
import nemo.collections.nlp as nemo_nlp
from nemo.collections.nlp.nm.trainables import TokenClassifier


class Inference:
    def __init__(self, tokenizer, model, use_gpu=True):
        placement = nemo.core.DeviceType.GPU if use_gpu else nemo.core.DeviceType.CPU
        self.nf = nemo.core.NeuralModuleFactory(
            backend=nemo.core.Backend.PyTorch,
            placement=placement)
        self.tokenizer = tokenizer
        self.model = model
        self.labels_dict = {0: 'O', 1: 'BRND', 2: 'EINF', 3: 'LINE', 4: 'NAME', 5: 'PKGT', 6: 'VOL'}
        self.hidden_size = model.hidden_size
        self.classifier = TokenClassifier(hidden_size=self.hidden_size,
                                          num_classes=len(self.labels_dict))
        if use_gpu:
            self.model = self.model.cuda()
            self.classifier = self.classifier.cuda()
        else:
            self.model = self.model.cpu()
            self.classifier = self.classifier.cpu()
        self.none_label = 'O'

    def forward(self, texts, batch_size=1, add_brackets=True):
        data_layer = nemo_nlp.nm.data_layers.BertTokenClassificationInferDataLayer(
            queries=texts, tokenizer=self.tokenizer, max_seq_length=32, batch_size=batch_size
        )
        input_ids, input_type_ids, input_mask, _, subtokens_mask = data_layer()
        hidden_states = self.model(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)
        logits = self.classifier(hidden_states=hidden_states)
        evaluated_tensors = self.nf.infer(tensors=[logits, subtokens_mask], checkpoint_dir="/data/weights/ner")
        # self.concatenate (helper not shown) merges the per-batch results into single arrays
        logits, subtokens_mask = [self.concatenate(tensors) for tensors in evaluated_tensors]
        preds = np.argmax(logits, axis=2)
        return preds  # post-processing omitted

where model is nemo_nlp.nm.trainables.huggingface.BERT, tokenizer is nemo_nlp.data.SentencePieceTokenizer, and "/data/weights/ner" contains the .pt checkpoints from training.

When I run forward with use_gpu=True it works, but with use_gpu=False it crashes with this error:
RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_mm
on this line:
---> 28 evaluated_tensors = self.nf.infer(tensors=[logits, subtokens_mask], checkpoint_dir="/data/weights/ner")

What am I doing wrong, and how can I fix it? Does it come from loading a checkpoint that was saved with CUDA tensors?
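
For what it's worth, in plain PyTorch a checkpoint saved from CUDA tensors stays mapped to GPU unless it is explicitly remapped at load time. NeMo's infer() loads the checkpoint itself, so this may not apply directly, but a minimal PyTorch-only sketch of the usual remedy (with a hypothetical checkpoint filename) looks like this:

import torch

# Remap CUDA-saved tensors to CPU when loading; "model.pt" is a placeholder path.
state_dict = torch.load("model.pt", map_location="cpu")
# model.load_state_dict(state_dict)  # assuming `model` is the matching nn.Module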

@ekmb ekmb self-assigned this Apr 6, 2020

ekmb commented Apr 7, 2020

@lebionick I couldn't reproduce your error; could you share the code for your model and tokenizer initialization?
This worked for me:

import nemo
import nemo.collections.nlp as nemo_nlp
from nemo.collections.nlp.nm.trainables import TokenClassifier
import numpy as np

class Inference:
    def __init__(self, use_gpu=True):
        placement = nemo.core.DeviceType.GPU if use_gpu else nemo.core.DeviceType.CPU
        self.nf = nemo.core.NeuralModuleFactory(
            backend=nemo.core.Backend.PyTorch,
            placement=placement)
        self.model = nemo_nlp.nm.trainables.huggingface.BERT(pretrained_model_name="bert-base-uncased")
        self.tokenizer = nemo.collections.nlp.data.tokenizers.get_tokenizer(
            tokenizer_name="sentencepiece",
            pretrained_model_name="bert-base-uncased",
            tokenizer_model='tokenizer.model',
        )
        self.checkpoint_dir = 'checkpoints'
        self.labels_dict = {0: 'O', 1: 'BRND', 2: 'EINF', 3: 'LINE', 4: 'NAME', 5: 'PKGT', 6: 'VOL'}
        self.hidden_size = self.model.hidden_size
        self.classifier = TokenClassifier(hidden_size=self.hidden_size,
                                          num_classes=len(self.labels_dict))
        if use_gpu:
            self.model = self.model.cuda()
            self.classifier = self.classifier.cuda()
        else:
            self.model = self.model.cpu()
            self.classifier = self.classifier.cpu()
        self.none_label = 'O'

    def forward(self, texts, batch_size=1, add_brackets=True):
        data_layer = nemo_nlp.nm.data_layers.BertTokenClassificationInferDataLayer(
            queries=texts, tokenizer=self.tokenizer, max_seq_length=32, batch_size=batch_size
        )
        input_ids, input_type_ids, input_mask, _, subtokens_mask = data_layer()
        hidden_states = self.model(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)
        logits = self.classifier(hidden_states=hidden_states)
        evaluated_tensors = self.nf.infer(tensors=[logits, subtokens_mask], checkpoint_dir=self.checkpoint_dir)
        return 'ok'

inf = Inference(False)
inf.forward(['Example #1', 'Example #2'])

@lebionick
Author

@ekmb I started putting together a minimal working example and found the cause: I load BERT earlier in the code (in a Jupyter cell) using this:

bert_model = nemo_nlp.nm.trainables.huggingface.BERT(
    config_filename="/data/mwe/checkpoints/bert-config.json"
)

At that moment I already had a global nemo.core.NeuralModuleFactory with GPU placement (the one I actually used for NER training).
So I fixed the issue by loading the model inside __init__, after creating the NeuralModuleFactory.
Could you please explain how this works: where does nemo_nlp.nm.trainables.huggingface.BERT look for the NeuralModuleFactory?
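
For anyone hitting the same thing, a minimal sketch of the fixed ordering described above (same config path as in my snippet; the assumption is that modules created after the factory attach to the currently active factory and inherit its placement):

import nemo
import nemo.collections.nlp as nemo_nlp

# Create the factory with CPU placement first, so modules built afterwards pick it up.
nf = nemo.core.NeuralModuleFactory(
    backend=nemo.core.Backend.PyTorch,
    placement=nemo.core.DeviceType.CPU)

# Only now construct BERT, under the CPU-placed factory above.
bert_model = nemo_nlp.nm.trainables.huggingface.BERT(
    config_filename="/data/mwe/checkpoints/bert-config.json"
)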
