Further changes to fix fine-tuning
Signed-off-by: SeanNaren <[email protected]>
SeanNaren committed Mar 22, 2023
1 parent 22527b1 commit f2451f8
Showing 3 changed files with 37 additions and 7 deletions.
First changed file: nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py (22 changes: 20 additions & 2 deletions)
@@ -382,6 +382,7 @@ def __init__(
max_seq_length_decoder: int = 128,
use_cache: bool = True,
prefix_override: str = None,
+ pad_to_max_length: bool = True,
):
"""
Processes GLUE datasets
@@ -392,10 +393,12 @@ def __init__(
max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
use_cache: whether to use data cache
prefix_override: if you want to override default prompt for this task specify this via a string.
+ pad_to_max_length: If true, pad to the maximum length.
"""
super().__init__(file_name, task_name, tokenizer, max_seq_length, use_cache, compute_features=False)
self.max_seq_length = max_seq_length
self.max_seq_length_decoder = max_seq_length_decoder
+ self.pad_to_max_length = pad_to_max_length
self.processor = processors[self.task_name]()
self.prefix_override = prefix_override
self.features = self.convert_examples_to_features()
@@ -412,9 +415,16 @@ def collate_fn(self, batch):
dec_input = [item['text_dec'] for item in batch]
labels = [item['labels'] for item in batch]

- max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0
max_enc_query_length = max([len(item) for item in enc_query]) if enc_query else 0
+ max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0
max_label_length = max([len(item) for item in labels]) if labels else 0
+ if self.pad_to_max_length:
+     assert max_enc_query_length <= self.max_seq_length
+     assert max_dec_input_length <= self.max_seq_length_decoder
+     assert max_label_length <= self.max_seq_length_decoder
+     max_enc_query_length = self.max_seq_length
+     max_dec_input_length = self.max_seq_length_decoder
+     max_label_length = self.max_seq_length_decoder

loss_mask = [([1] * (len(item))) + ([0] * (max_label_length - len(item))) for item in labels]
enc_query = [item + [self.tokenizer.pad_id] * (max_enc_query_length - len(item)) for item in enc_query]
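The block above is the core of the dataset-side change: by default the batch is now padded to the configured maximum lengths rather than to the longest item in the batch, so every batch comes out with the same static shape. A minimal, self-contained sketch of the two padding modes (the helper name, pad id of 0, and sizes are illustrative, not NeMo's API):

```python
from typing import List

def pad_batch(sequences: List[List[int]], max_len: int, pad_id: int = 0,
              pad_to_max_length: bool = True) -> List[List[int]]:
    """Pad token-id lists to a fixed max_len, or dynamically to the batch's longest item."""
    longest = max((len(s) for s in sequences), default=0)
    if pad_to_max_length:
        assert longest <= max_len, "an example exceeds the configured maximum length"
        target = max_len      # static shape: identical across batches
    else:
        target = longest      # dynamic shape: varies from batch to batch
    return [s + [pad_id] * (target - len(s)) for s in sequences]

# With pad_to_max_length=True both examples are padded to length 8, so every
# batch produced this way has the same tensor dimensions.
print(pad_batch([[5, 6, 7], [9, 9]], max_len=8))
print(pad_batch([[5, 6, 7], [9, 9]], max_len=8, pad_to_max_length=False))
```

Static batch shapes are what the model-side changes below rely on when they size the pipeline-parallel buffers from the dataset's maximum lengths.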
@@ -488,10 +498,18 @@ def __init__(
use_cache: bool = True,
prefix_override: str = None,
lang_list: List[str] = None,
+ pad_to_max_length: bool = True,
):
self.lang_list = set(lang_list)
super().__init__(
- file_name, task_name, tokenizer, max_seq_length, max_seq_length_decoder, use_cache, prefix_override
+ file_name,
+ task_name,
+ tokenizer,
+ max_seq_length,
+ max_seq_length_decoder,
+ use_cache,
+ prefix_override,
+ pad_to_max_length,
)
if len(lang_list) <= 0 or lang_list is None:
raise ValueError(f"Found an empty or None lang_list for {self.task_name}")
Second changed file:
@@ -82,3 +82,11 @@ def build_train_valid_test_datasets(self, stage):
self._train_ds = self._build_dataset(self.cfg.data.train_ds, check_implict_grad_acc=False)
logging.info(f'Length of train dataset: {len(self._train_ds)}')
logging.info(f'Finished building GLUE/XNLI datasets.')

+ @property
+ def max_decoder_seq_length(self) -> int:
+     return self._train_ds.max_seq_length_decoder
+
+ @property
+ def max_encoder_seq_length(self) -> int:
+     return self._train_ds.max_seq_length
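The new properties let the GLUE/XNLI fine-tuning model report the lengths its dataset actually pads to, overriding the defaults that the base encoder–decoder model derives from the pre-training config (third file below). A self-contained sketch of that override pattern, with illustrative class names and config values rather than NeMo's:

```python
from types import SimpleNamespace

class BaseT5Sketch:
    """Stand-in for the pre-training model: lengths come from the config."""
    def __init__(self, cfg):
        self.cfg = cfg

    @property
    def max_encoder_seq_length(self) -> int:
        return self.cfg.seq_length

    @property
    def max_decoder_seq_length(self) -> int:
        return self.cfg.data.seq_length_dec

class GlueFineTuneSketch(BaseT5Sketch):
    """Stand-in for the fine-tuning model: lengths come from the dataset."""
    def __init__(self, cfg, train_ds):
        super().__init__(cfg)
        self._train_ds = train_ds

    @property
    def max_encoder_seq_length(self) -> int:
        return self._train_ds.max_seq_length

    @property
    def max_decoder_seq_length(self) -> int:
        return self._train_ds.max_seq_length_decoder

cfg = SimpleNamespace(seq_length=2048, data=SimpleNamespace(seq_length_dec=512),
                      micro_batch_size=4, encoder=SimpleNamespace(hidden_size=1024))
ds = SimpleNamespace(max_seq_length=512, max_seq_length_decoder=128)

model = GlueFineTuneSketch(cfg, ds)
# training_step sizes its buffers through the property, so fine-tuning uses
# the dataset lengths (512/128) instead of the pre-training ones (2048/512).
print([model.max_encoder_seq_length, cfg.micro_batch_size, cfg.encoder.hidden_size])
```

Because the training loop reads the lengths through these properties rather than the raw config, the same code path serves pre-training and fine-tuning without branching.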
Third changed file:
@@ -306,7 +306,7 @@ def training_step(self, dataloader_iter, batch_idx):
# we zero grads here because we also call backward in the megatron fwd/bwd functions
self._optimizer.zero_grad()

- tensor_shape = [self.cfg.seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]
+ tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]

fwd_bwd_function = get_forward_backward_func()

@@ -317,7 +317,7 @@ def training_step(self, dataloader_iter, batch_idx):
num_microbatches=get_num_microbatches(),
forward_only=False,
tensor_shape=tensor_shape,
- decoder_seq_length=self.decoder_seq_length,
+ decoder_seq_length=self.max_decoder_seq_length,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)
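tensor_shape is what the pipeline-parallel forward/backward schedule uses to size the activations it exchanges between stages, so it has to match the [sequence, micro-batch, hidden] shape of the batches the fine-tuning dataset actually produces; sizing it from the pre-training seq_length appears to be the fine-tuning breakage this commit addresses. A small sketch of the invariant, using torch and made-up sizes:

```python
import torch

# Illustrative sizes; the real values come from the model config and dataset.
micro_batch_size = 4
hidden_size = 1024
max_encoder_seq_length = 512   # length the GLUE dataset pads encoder inputs to

# Shape the pipeline schedule pre-allocates for inter-stage communication:
# [sequence, micro-batch, hidden].
tensor_shape = [max_encoder_seq_length, micro_batch_size, hidden_size]

# Shape of the activations actually produced from one padded micro-batch
# (Megatron keeps activations sequence-first).
enc_input_ids = torch.zeros(micro_batch_size, max_encoder_seq_length, dtype=torch.long)
activations = torch.zeros(enc_input_ids.shape[1], enc_input_ids.shape[0], hidden_size)

# pad_to_max_length=True guarantees this holds for every batch; sizing
# tensor_shape from the pre-training seq_length instead could violate it.
assert list(activations.shape) == tensor_shape
print("buffer and activation shapes agree:", tensor_shape)
```

decoder_seq_length plays the same role for the decoder side, which is why it is now read from max_decoder_seq_length as well.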
@@ -383,9 +383,13 @@ def training_step(self, dataloader_iter, batch_idx):
return loss_mean

@property
- def decoder_seq_length(self) -> int:
+ def max_decoder_seq_length(self) -> int:
return self._cfg.data.seq_length_dec

+ @property
+ def max_encoder_seq_length(self) -> int:
+     return self.cfg.seq_length

def backward(self, *args, **kwargs):
""" LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
@@ -612,7 +616,7 @@ def validation_step(self, dataloader_iter, batch_idx):
"""
return_values - if given, returns a dictionary with given keys and corresponding values
"""
- tensor_shape = [self.cfg.seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]
+ tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]
fwd_bwd_func = get_forward_backward_func()
losses_reduced_per_micro_batch = fwd_bwd_func(
forward_step_func=self.get_forward_output_and_loss_func(),
@@ -621,7 +625,7 @@
forward_only=True,
tensor_shape=tensor_shape,
num_microbatches=get_num_microbatches(),
- decoder_seq_length=self.decoder_seq_length,
+ decoder_seq_length=self.max_decoder_seq_length,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)
