Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt Learning Notebook Bug Fix #4689

Merged
merged 5 commits on Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def main(cfg) -> None:
max_input_length = model.frozen_model.cfg.encoder_seq_length - length_params["max_length"]

_, dataloader = model.build_virtual_prompt_dataset(
dataset_paths=cfg.data_paths,
data=cfg.data_paths,
batch_size=64,
max_seq_length=max_input_length,
min_seq_length=model.cfg.data.get('min_seq_length', 1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class GPTPromptLearningDataset(Dataset):
The dataset class for prompt-tuning or p-tuning pretrained GPT models.

Args:
dataset_paths (list[strings]): paths to .jsonl or .json files
data (list[strings], list[dicts]): (1) paths to .jsonl or .json files, (2) dict objects corresponding to each input example
tokenizer (tokenizer): Tokenizer from frozen language model
virtual_prompt_source (Enum): Either VirtualPromptSource.PROMPT_TABLE or VirtualPromptSource.PROMPT_ENCODER
task_templates (dict): Dictionary containing all task template information needed to format prompts. Created in the GPTPromptLearningModel class.
Expand All @@ -46,7 +46,7 @@ class GPTPromptLearningDataset(Dataset):

def __init__(
self,
dataset_paths,
data,
tokenizer,
virtual_prompt_source: VirtualPromptSource,
task_templates: dict,
Expand Down Expand Up @@ -80,13 +80,17 @@ def __init__(

logging.info("Loading and tokenizing dataset ... ")

# Data is just a list of dicts already loaded from a json file or passed in directly as a dict
if isinstance(data[0], dict):
self.load_data(data)

# Datasets are a list of file path strings to .json or .jsonl files
if isinstance(dataset_paths[0], str):
for path in dataset_paths:
elif isinstance(data[0], str):
for path in data:
dataset = open(path, 'r', encoding='utf-8')
self.load_data(dataset)
else:
raise ValueError("Datasets must be a list of filepath strings")
raise ValueError("Datasets must be a list of filepath strings or a list of data example dicts")

def load_data(self, dataset):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def setup(self, stage=None):
def setup_training_data(self, training_data_config=None):
if self.cfg.data.get('train_ds', None):
self._train_ds, self._train_dl = self.build_virtual_prompt_dataset(
dataset_paths=self.cfg.data.train_ds,
data=self.cfg.data.train_ds,
batch_size=self.cfg.global_batch_size,
max_seq_length=self.frozen_model.cfg.encoder_seq_length,
min_seq_length=self.cfg.data.get('min_seq_length', 1),
Expand All @@ -702,7 +702,7 @@ def setup_training_data(self, training_data_config=None):
def setup_validation_data(self, validation_data_config=None):
if self.cfg.data.get('validation_ds', None):
self._validation_ds, self._validation_dl = self.build_virtual_prompt_dataset(
dataset_paths=self.cfg.data.validation_ds,
data=self.cfg.data.validation_ds,
batch_size=self.cfg.global_batch_size,
max_seq_length=self.frozen_model.cfg.encoder_seq_length,
min_seq_length=self.cfg.data.get('min_seq_length', 1),
Expand All @@ -718,7 +718,7 @@ def setup_validation_data(self, validation_data_config=None):
def setup_test_data(self, test_data_config=None):
if self.cfg.data.get('test_ds', None):
self._test_ds, self._test_dl = self.build_virtual_prompt_dataset(
dataset_paths=self.cfg.data.test_ds,
data=self.cfg.data.test_ds,
batch_size=self.cfg.global_batch_size,
max_seq_length=self.frozen_model.cfg.encoder_seq_length,
min_seq_length=self.cfg.data.get('min_seq_length', 1),
Expand All @@ -733,7 +733,7 @@ def setup_test_data(self, test_data_config=None):

def build_virtual_prompt_dataset(
self,
dataset_paths,
data,
batch_size=None,
max_seq_length=2048,
min_seq_length=1,
Expand All @@ -748,7 +748,7 @@ def build_virtual_prompt_dataset(
get_dataset_only=False,
):
dataset = GPTPromptLearningDataset(
dataset_paths=dataset_paths,
data=data,
tokenizer=self.tokenizer,
virtual_prompt_source=self.virtual_prompt_source,
task_templates=self.task_templates,
Expand Down Expand Up @@ -885,9 +885,14 @@ def dummy():

max_input_length = self.frozen_model.cfg.encoder_seq_length - length_params["max_length"]

dataset_paths = [path["data_path"] for path in inputs]
# input dicts are either dataset paths or already loaded example dicts
if "taskname" not in inputs[0].keys():
data = [path["data_path"] for path in inputs]
else:
data = inputs

dataset = self.build_virtual_prompt_dataset(
dataset_paths=dataset_paths,
data=data,
max_seq_length=max_input_length,
min_seq_length=self.cfg.data.get('min_seq_length', 1),
add_bos=sampling_params["add_BOS"],
Expand Down
2 changes: 1 addition & 1 deletion tests/collections/nlp/test_prompt_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_prompt_tuning_dataset(
dataset_path, tokenizer, virtual_prompt_source, task_templates, pseudo_tokens,
):
dataset = GPTPromptLearningDataset(
dataset_paths=[dataset_path],
data=[dataset_path],
tokenizer=tokenizer,
virtual_prompt_source=virtual_prompt_source,
task_templates=task_templates,
Expand Down