Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt tokenization bugfix #4197

Merged
merged 5 commits into from
May 19, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ model:
seed: 1234
nemo_path: ${name}.nemo # .nemo filename and path to where the model and associated artifacts should be saved
lm_finetune: False # whether fine tune the language model
pseudo_token_base: 'PROMPT_' # pseudo prompt tokens
virtual_prompt_style: 'p-tuning' # one of 'prompt-tuning', 'p-tuning', or 'inference'
encoder_seq_length: 2048
tensor_model_parallel_size: 1 # intra-layer model parallelism
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _insert_text_in_template(self, input_example, prompt_template_fields, doc):
return input_example

def _insert_virtual_token_placeholders(self, input_example, virtual_token_splits):
""" Insert the correct number of pseudo tokens at the <|virtual_PROMPT_n|> markers """
""" Insert the correct number of pseudo tokens at the <|VIRTUAL_PROMPT_n|> markers """
total_inserted_tokens = 0

for idx in range(len(virtual_token_splits)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common import PromptEncoder, PromptTable, VirtualPromptSource, VirtualPromptStyle
from nemo.collections.nlp.modules.common import (
PromptEncoder,
PromptTable,
VirtualPromptPlaceholderToken,
VirtualPromptSource,
VirtualPromptStyle,
)
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.collections.nlp.modules.common.text_generation_utils import (
get_default_length_params,
Expand Down Expand Up @@ -104,8 +110,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
)

# Prepare pseudo token ids for virtual/virtual prompt tokens
self.pseudo_token_base = cfg.pseudo_token_base
self.pseudo_tokens = [self.pseudo_token_base + str(i) for i in range(self.max_virtual_tokens)]
self.pseudo_tokens = [
VirtualPromptPlaceholderToken.BASE.value + str(i) + VirtualPromptPlaceholderToken.END.value
for i in range(self.max_virtual_tokens)
]
self.tokenizer.add_special_tokens({'additional_special_tokens': self.pseudo_tokens})
self.pseudo_token_ids = self.tokenizer.tokens_to_ids(self.pseudo_tokens)
self.pseudo_token_ids_start = self.pseudo_token_ids[0]
Expand Down
7 changes: 6 additions & 1 deletion nemo/collections/nlp/modules/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
)
from nemo.collections.nlp.modules.common.lm_utils import get_lm_model, get_pretrained_lm_models_list
from nemo.collections.nlp.modules.common.prompt_encoder import PromptEncoder
from nemo.collections.nlp.modules.common.prompt_table import PromptTable, VirtualPromptSource, VirtualPromptStyle
from nemo.collections.nlp.modules.common.prompt_table import (
PromptTable,
VirtualPromptPlaceholderToken,
VirtualPromptSource,
VirtualPromptStyle,
)
from nemo.collections.nlp.modules.common.sequence_classifier import SequenceClassifier
from nemo.collections.nlp.modules.common.sequence_regression import SequenceRegression
from nemo.collections.nlp.modules.common.sequence_token_classifier import SequenceTokenClassifier
Expand Down
7 changes: 6 additions & 1 deletion nemo/collections/nlp/modules/common/prompt_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

__all__ = ['PromptTable', 'VirtualPromptSource', 'VirtualPromptStyle']
__all__ = ['PromptTable', 'VirtualPromptSource', 'VirtualPromptStyle', 'VirtualPromptPlaceholderToken']


class VirtualPromptStyle(enum.Enum):
Expand All @@ -42,6 +42,11 @@ class VirtualPromptSource(enum.Enum):
PROMPT_ENCODER = 'prompt_encoder'


class VirtualPromptPlaceholderToken(enum.Enum):
    """String pieces used to build pseudo-token names for virtual prompts.

    A placeholder token is assembled as BASE + index + END (e.g. '<prompt_0>'),
    then registered with the tokenizer as an additional special token.
    """

    # Prefix of every placeholder pseudo token.
    BASE = '<prompt_'
    # Suffix closing the placeholder pseudo token.
    END = '>'


class PromptTable(NeuralModule, Exportable):
def __init__(self, existing_tasks, task_templates, task_id_num_to_name, hidden_size):
super().__init__()
Expand Down
28 changes: 17 additions & 11 deletions tests/collections/nlp/test_prompt_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch

from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset
from nemo.collections.nlp.modules.common import VirtualPromptSource
from nemo.collections.nlp.modules.common import VirtualPromptPlaceholderToken, VirtualPromptSource
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.core import Dataset

Expand Down Expand Up @@ -61,7 +61,7 @@ def create_temp_dataset():
def get_task_templates():
task_templates = {}
task_templates['task name A'] = {
"prompt_template": "<|VIRTUAL_PROMPT_0|>{text}{answer}",
"prompt_template": "<|VIRTUAL_PROMPT_0|> {text}{answer}",
"prompt_template_fields": ['text', 'answer'],
"total_virtual_tokens": 5,
"virtual_token_splits": [5],
Expand All @@ -71,8 +71,8 @@ def get_task_templates():
"task_id_num": 0,
}
task_templates['task name B'] = {
"prompt_template": "<|VIRTUAL_PROMPT_0|>{question}<|VIRTUAL_PROMPT_1|>{answer}{extra}",
"prompt_template_fields": ['question', 'answer'],
"prompt_template": "<|VIRTUAL_PROMPT_0|> {question} <|VIRTUAL_PROMPT_1|> {answer}{extra}",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you add the spaces around the virtual tokens? If a space is needed, shouldn't the virtual prompt learn to be a space embedding?

Copy link
Contributor Author

@vadam5 vadam5 May 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was playing around with the unit tests to help diagnose a bug with the huggingface tokenizer. This change isn't needed, I just forgot to put it back. I removed the spaces so they should be back the way they were.

"prompt_template_fields": ['question', 'answer', 'extra'],
"total_virtual_tokens": 10,
"virtual_token_splits": [7, 3],
"truncate_field": None,
Expand All @@ -83,6 +83,15 @@ def get_task_templates():
return task_templates


def get_pseudo_tokens(total_virtual_tokens):
    """Return placeholder pseudo-token strings for virtual prompt tokens.

    Each token has the form BASE + index + END (e.g. '<prompt_0>'), one per
    virtual token position in ``range(total_virtual_tokens)``.
    """
    base = VirtualPromptPlaceholderToken.BASE.value
    end = VirtualPromptPlaceholderToken.END.value
    return [f"{base}{index}{end}" for index in range(total_virtual_tokens)]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see you have similar logic in the prompt_learning_model file. Could you put the function there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can do that


class TestMegatronGPTPromptLearningDataset:
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -92,9 +101,8 @@ def test_init_prompt_learning_dataset(self):
dataset_path = create_temp_dataset()

# Setup virtual token place holders
pseudo_token_base = 'PROMPT_'
max_virtual_tokens = 10
pseudo_tokens = [pseudo_token_base + str(i) for i in range(max_virtual_tokens)]
total_virtual_tokens = 10
pseudo_tokens = get_pseudo_tokens(total_virtual_tokens)
tokenizer.add_special_tokens({'additional_special_tokens': pseudo_tokens})

dataset = get_prompt_tuning_dataset(
Expand All @@ -119,9 +127,8 @@ def test_prompt_learning_dataset_collate_fn_prompt_table(self):
dataset_path = create_temp_dataset()

# Setup virtual token place holders
pseudo_token_base = 'PROMPT_'
total_virtual_tokens = 10
pseudo_tokens = [pseudo_token_base + str(i) for i in range(total_virtual_tokens)]
pseudo_tokens = get_pseudo_tokens(total_virtual_tokens)
tokenizer.add_special_tokens({'additional_special_tokens': pseudo_tokens})

dataset = get_prompt_tuning_dataset(
Expand Down Expand Up @@ -153,9 +160,8 @@ def test_prompt_learning_dataset_collate_fn_prompt_encoder(self):
dataset_path = create_temp_dataset()

# Setup virtual token place holders
pseudo_token_base = 'PROMPT_'
total_virtual_tokens = 10
pseudo_tokens = [pseudo_token_base + str(i) for i in range(total_virtual_tokens)]
pseudo_tokens = get_pseudo_tokens(total_virtual_tokens)
tokenizer.add_special_tokens({'additional_special_tokens': pseudo_tokens})

dataset = get_prompt_tuning_dataset(
Expand Down