diff --git a/Jenkinsfile b/Jenkinsfile index 42d60bb2809d..4d14b18fad62 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -3298,6 +3298,74 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" } } + stage('L2: Megatron GPT Finetuning PP=2') { + when { + anyOf { + branch 'main' + changeRequest target: 'main' + } + } + failFast true + steps { + sh "python examples/nlp/language_modeling/tuning/megatron_gpt_sft.py \ + trainer.devices=2 \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + +trainer.limit_val_batches=2 \ + trainer.max_steps=3 \ + trainer.precision=16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_sft_results \ + model.pipeline_model_parallel_size=2 \ + model.tensor_model_parallel_size=1 \ + model.restore_from_path=/home/TestData/nlp/megatron_gpt/PP2/gpt_pp2_tp1.nemo \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.data.train_ds.micro_batch_size=1 \ + model.data.train_ds.global_batch_size=4 \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \ + model.data.train_ds.concat_sampling_probabilities=[0.3,0.7] \ + model.data.train_ds.num_workers=0 \ + model.data.test_ds.micro_batch_size=1 \ + model.data.test_ds.global_batch_size=4 \ + model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \ + model.data.test_ds.names=[quarel,trec] \ + model.data.validation_ds.micro_batch_size=1 \ + model.data.validation_ds.global_batch_size=4 \ + model.data.validation_ds.num_workers=0 \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \ + model.data.validation_ds.names=[quarel,trec]" + sh "python examples/nlp/language_modeling/tuning/megatron_gpt_sft.py \ + trainer.devices=2 \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + +trainer.limit_val_batches=2 \ + trainer.max_steps=3 \ + trainer.precision=16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_sft_results \ + model.pipeline_model_parallel_size=2 \ + model.tensor_model_parallel_size=1 \ + model.restore_from_path=/home/TestData/nlp/megatron_gpt/PP2/gpt_pp2_tp1.nemo \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.data.train_ds.micro_batch_size=1 \ + model.data.train_ds.global_batch_size=4 \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \ + model.data.train_ds.concat_sampling_probabilities=[0.3,0.7] \ + model.data.train_ds.num_workers=0 \ + model.data.test_ds.micro_batch_size=1 \ + model.data.test_ds.global_batch_size=4 \ + model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \ + model.data.test_ds.names=[quarel,trec] \ + model.data.validation_ds.micro_batch_size=1 \ + model.data.validation_ds.global_batch_size=4 \ + model.data.validation_ds.num_workers=0 \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \ + model.data.validation_ds.names=[quarel,trec]" + sh "rm -rf examples/nlp/language_modeling/gpt_sft_results" + } + } stage('L2: Megatron GPT Eval') { when { anyOf { diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml 
b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml new file mode 100644 index 000000000000..12db9133104a --- /dev/null +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml @@ -0,0 +1,164 @@ +name: megatron_gpt_sft + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + replace_sampler_ddp: False + max_epochs: 9999 + max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 200 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 2 + mode: max + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below, + filename: 'megatron_gpt_sft--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + answer_only_loss: False # not used right now + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... 
Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 4 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'input' + label_key: 'output' + add_eos: True + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: null # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + + validation_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 4 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + context_key: 'input' + label_key: 'output' + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + test_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 4 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + context_key: 'input' + label_key: 'output' + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. 
Example: "Q: {input}\nA: {output}" + + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + optim: + name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work. + lr: 3e-5 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py new file mode 100644 index 000000000000..b2b8786df8c1 --- /dev/null +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning import Trainer +from pytorch_lightning.plugins.environments import TorchElasticEnvironment +from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.parts.nlp_overrides import ( + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPStrategy, + NLPSaveRestoreConnector, + PipelineMixedPrecisionPlugin, +) +from nemo.core.config import hydra_runner +from nemo.utils import AppState, logging +from nemo.utils.exp_manager import exp_manager +from nemo.utils.model_utils import inject_model_parallel_rank + + +def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): + """ + This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg). + The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. 
+ """ + OmegaConf.set_struct(gpt_cfg, True) + OmegaConf.resolve(cfg) + with open_dict(gpt_cfg): + gpt_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) + gpt_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size + gpt_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size + gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False) + gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None) + gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None) + gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None) + gpt_cfg.data = cfg.model.data + gpt_cfg.optim = cfg.model.optim + gpt_cfg.precision = cfg.trainer.precision + gpt_cfg.answer_only_loss = cfg.model.answer_only_loss + gpt_cfg.restore_from_path = cfg.model.restore_from_path + gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint + gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end + gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view + gpt_cfg.hidden_dropout = cfg.model.get('hidden_dropout', 0.0) + gpt_cfg.attention_dropout = cfg.model.get('attention_dropout', 0.0) + gpt_cfg.ffn_dropout = cfg.model.ffn_dropout + + # This is needed when modifying a hparam file directly to load `.ckpt` files. + # This is not needed to modify the cfg in `.nemo` files. + if add_cfg_to_tree: + OmegaConf.resolve(gpt_cfg) + gpt_cfg.cfg = gpt_cfg + + return gpt_cfg + + +def load_from_nemo(cls, cfg, trainer, gpt_cfg, modify_confg_fn): + gpt_cfg = modify_confg_fn(gpt_cfg, cfg, add_cfg_to_tree=False) + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.model.restore_from_path): + save_restore_connector.model_extracted_dir = cfg.model.restore_from_path + model = cls.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + override_config_path=gpt_cfg, + save_restore_connector=save_restore_connector, + ) + return model + + +def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn): + app_state = AppState() + if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1: + app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size + app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.model_parallel_size, + app_state.data_parallel_size, + app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank_=cfg.model.pipeline_model_parallel_split_rank, + ) + checkpoint_path = inject_model_parallel_rank( + os.path.join(cfg.model.pretrained_checkpoint.checkpoint_dir, cfg.model.pretrained_checkpoint.checkpoint_name) + ) + hparams_file = OmegaConf.load(cfg.model.pretrained_checkpoint.hparams_file) + gpt_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True) + with tempfile.NamedTemporaryFile(suffix='.yaml') as f: + OmegaConf.save(config=gpt_cfg, f=f.name) + model = 
cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,) + return model + + +def validate_checkpoint_loading_args(cfg): + if cfg.checkpoint_dir is None or not os.path.isdir(cfg.checkpoint_dir): + raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.') + if cfg.checkpoint_name is None: + raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.') + if cfg.hparams_file is None or not os.path.isfile(cfg.hparams_file): + raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.') + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_sft") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False) + with_distributed_adam = cfg.model.optim.get('name', 'fused_adam') == 'distributed_fused_adam' + plugins = [] + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, + find_unused_parameters=False, + ) + if cfg.trainer.precision in [16, 'bf16']: + scaler = None + if cfg.trainer.precision == 16: + scaler = GradScaler( + init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + growth_interval=cfg.model.get('native_amp_growth_interval', 1000), + hysteresis=cfg.model.get('hysteresis', 2), + ) + if megatron_amp_o2 and not with_distributed_adam: + plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler)) + + if cfg.get('cluster_type', None) == 'BCP': + plugins.append(TorchElasticEnvironment()) + + trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) + + exp_manager(trainer, cfg.exp_manager) + + # update resume from checkpoint found by exp_manager + if cfg.model.resume_from_checkpoint is not None: + resume_from_checkpoint = cfg.model.resume_from_checkpoint + else: + resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path + logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}') + + trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint) + + if cfg.model.restore_from_path: + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.model.restore_from_path): + save_restore_connector.model_extracted_dir = cfg.model.restore_from_path + gpt_cfg = MegatronGPTSFTModel.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, + ) + model = load_from_nemo(MegatronGPTSFTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config) + else: + validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) + model = load_from_checkpoint_dir(MegatronGPTSFTModel, cfg, trainer, modify_confg_fn=_modify_config) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/common/metrics/metric_string_to_torchmetric.py b/nemo/collections/common/metrics/metric_string_to_torchmetric.py index e83d88057143..2d1e094a0d8b 100644 --- a/nemo/collections/common/metrics/metric_string_to_torchmetric.py +++ b/nemo/collections/common/metrics/metric_string_to_torchmetric.py @@ -13,6 +13,7 @@ # limitations under the License. 
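# [Illustrative aside, not part of the patch] The `.ckpt` loading path in megatron_gpt_sft.py above
# (validate_checkpoint_loading_args / load_from_checkpoint_dir) reads a `model.pretrained_checkpoint`
# block that is not present in the example YAML. A minimal sketch of that block, with hypothetical paths:
#
#   model:
#     pretrained_checkpoint:
#       checkpoint_dir: /path/to/checkpoints          # directory holding the .ckpt files
#       checkpoint_name: megatron_gpt--step=1000.ckpt # hypothetical checkpoint file name
#       hparams_file: /path/to/hparams.yaml           # Lightning hparams file saved with the run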
from torchmetrics import Accuracy, AveragePrecision, F1Score, MatthewsCorrCoef, PearsonCorrCoef, SpearmanCorrCoef +from torchmetrics.text.rouge import ROUGEScore from nemo.collections.common.metrics.classification_accuracy import ExactStringMatchMetric @@ -28,4 +29,5 @@ 'spearman_corr_coef': SpearmanCorrCoef, 'matthews_corr_coef': MatthewsCorrCoef, 'exact_string_match': ExactStringMatchMetric, + 'rouge': ROUGEScore, } diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py new file mode 100644 index 000000000000..d05bb7191ae1 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -0,0 +1,281 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import get_samples_mapping +from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset +from nemo.core.classes import Dataset + +__all__ = ['GPTSFTDataset'] + + +class GPTSFTDataset(Dataset): + def __init__( + self, + file_path: str, + tokenizer: TokenizerSpec, + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + add_sep: bool = False, + sep_id: int = None, + max_num_samples: int = None, + seed: int = 1234, + context_key: str = "text", + label_key: str = "answer", + separate_prompt_and_response_with_newline: bool = False, + answer_only_loss: bool = True, + truncation_field: str = "answer", + pad_to_max_length: bool = True, + index_mapping_dir: str = None, + prompt_template: str = None, + ): + """ + file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). + max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + add_sep (bool): Whether to add a separation token to each data example (goes between prompt and answer) + tokens_to_generate (int): (inference only) Number of tokens to generate during inference + seed: Random seed for data shuffling. 
+ max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. + seed: int = 1234, + context_key: Key to use for the context in your JSONL file + label_key: Key to use for the label in your JSONL file + separate_prompt_and_response_with_newline: Adds a newline between prompt and response. + answer_only_loss: If True, will compute the loss only on the answer part of the input. If False, will compute the loss on the entire input. + truncation_field: Field to use for truncation. (Options: "answer", "context"). Field to be used for truncation if the combined length exceeds the max sequence length. + pad_to_max_length: Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch. + index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. + prompt_template: Prompt template to inject via an fstring. Formatted like Q: {input}\n\nA: {output} + """ + self.tokenizer = tokenizer + self.file_path = file_path + self.max_seq_length = max_seq_length + self.min_seq_length = min_seq_length + self.add_bos = add_bos + self.add_eos = add_eos + self.add_sep = add_sep + self.sep_id = sep_id + self.max_num_samples = max_num_samples + self.seed = seed + self.context_key = context_key + self.label_key = label_key + self.separate_prompt_and_response_with_newline = separate_prompt_and_response_with_newline + self.answer_only_loss = answer_only_loss + self.truncation_field = truncation_field + self.pad_to_max_length = pad_to_max_length + self.index_mapping_dir = index_mapping_dir + self.prompt_template = prompt_template + assert self.truncation_field in ["answer", "context"] + + self.indexed_dataset = JSONLMemMapDataset(dataset_paths=[file_path], tokenizer=None, header_lines=0) + + # Will be None after this call if `max_num_samples` is None + self._build_samples_mapping() + + def _build_samples_mapping(self): + if self.max_num_samples is not None: + self.samples_mapping = get_samples_mapping( + indexed_dataset=self.indexed_dataset, + data_prefix=self.file_path, + num_epochs=None, + max_num_samples=self.max_num_samples, + max_seq_length=self.max_seq_length - 2, + short_seq_prob=0, + seed=self.seed, + name=self.file_path.split('/')[-1], + binary_head=False, + index_mapping_dir=self.index_mapping_dir, + ) + else: + self.samples_mapping = None + + def __len__(self): + if self.max_num_samples is None: + return len(self.indexed_dataset) + else: + return len(self.samples_mapping) + + def __getitem__(self, idx): + if isinstance(idx, np.int64): + idx = idx.item() + + if self.samples_mapping is not None: + assert idx < len(self.samples_mapping) + idx, _, _ = self.samples_mapping[idx] + if isinstance(idx, np.uint32): + idx = idx.item() + + assert idx < len(self.indexed_dataset) + example = self.indexed_dataset[idx] + return self._process_example(example) + + def _process_example(self, example): + """ + Create an example by concatenating text and answer. + Truncation is carried out when needed, but it is performed only on the prompt side. + BOS, EOS, and SEP, are added if specified. 
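For illustration (hypothetical values): with prompt_template "Q: {input}\nA: {output}",
context "Who wrote Hamlet?" and output "Shakespeare", the code below builds
    context -> "Q: Who wrote Hamlet?\nA:"            ({output} removed, trailing space stripped)
    text    -> "Q: Who wrote Hamlet?\nA: Shakespeare"
and the answer tokens are whatever text_to_ids(text) produces beyond len(text_to_ids(context)).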
+ """ + context = example[self.context_key] + output = example[self.label_key] + + if self.prompt_template is not None: + assert '{input}' in self.prompt_template + assert '{output}' in self.prompt_template + # Make sure that '{output}' always occurs at the end of the prompt template string + assert self.prompt_template.index('{output}') == len(self.prompt_template) - len('{output}') + # Get the context by replacing only the input + original_context = context + context = self.prompt_template.replace('{input}', context).replace('{output}', '').strip(' ') + # Replace the input and output placeholders with the actual input and output + text = self.prompt_template.replace('{input}', original_context).replace('{output}', output) + + if self.separate_prompt_and_response_with_newline and self.prompt_template is None: + text = context + '\n' + output + elif not self.separate_prompt_and_response_with_newline and self.prompt_template is None: + text = context + ' ' + output + + tokenized_text = self.tokenizer.text_to_ids(text) + context_ids = self.tokenizer.text_to_ids(context) + answer_ids = tokenized_text[len(context_ids) :] + total_ids = len(context_ids) + len(answer_ids) + if self.add_bos: + total_ids += 1 + if self.add_sep: + total_ids += 1 + if self.add_eos: + total_ids += 1 + + # If the total number of token is greater than the max, we will try to truncate the answer + if total_ids > self.max_seq_length: + truncation_length = total_ids - self.max_seq_length + if self.truncation_field == "answer": + answer_ids = answer_ids[: -min(truncation_length, len(answer_ids))] + elif self.truncation_field == "context": + context_ids = context_ids[: -min(truncation_length, len(context_ids))] + + if len(context_ids) > self.max_seq_length: + context_ids = context_ids[: self.max_seq_length] + + assert len(context_ids) <= self.max_seq_length + input_ids = context_ids + + answer_start_idx = len(input_ids) + # Adds sep token between text/prompt and answer + if self.add_sep: + input_ids = input_ids + [self.sep_id] + answer_start_idx += 1 + + input_ids = input_ids + answer_ids + + if self.add_bos: + input_ids = [self.tokenizer.bos_id] + input_ids + answer_start_idx += 1 + if self.add_eos: + input_ids = input_ids + [self.tokenizer.eos_id] + + if len(input_ids) < self.min_seq_length or len(input_ids) > self.max_seq_length: + input_ids = input_ids[: self.max_seq_length] + + processed_example = { + 'input_ids': input_ids, + 'answer_start_idx': answer_start_idx, + 'context_ids': context_ids, + 'context_length': len(context_ids), + } + + return processed_example + + def _maybe_cast_to_list(self, x): + if isinstance(x, np.ndarray): + return [item.tolist() for item in x] + return x + + def _round_to_nearest(self, n, m): + return (n + m - 1) // m * m + + def _collate_item(self, item, max_length, pad_id): + item = self._maybe_cast_to_list(item) + # max_length = max([len(x) for x in item]) if item else 0 + # here [0] should be tokenizer.pad_id + item = [x + [pad_id] * (max_length - len(x)) for x in item] + return item + + def _build_loss_mask(self, processed_example): + """ Pad input_ids in batch to max batch length while building loss mask """ + input_ids = processed_example['input_ids'] + answer_start_idx = processed_example['answer_start_idx'] + if self.answer_only_loss: + loss_mask = [float(idx > answer_start_idx) for idx in range(len(input_ids))] + else: + loss_mask = [1.0] * len(input_ids) + + return loss_mask + + @torch.no_grad() + def _create_attention_mask(self, max_length): + """Create `attention_mask`. 
+ Args: + input_ids: A 1D tensor that holds the indices of tokens. + """ + # seq_length = len(input_ids) + # `attention_mask` has the shape of [1, seq_length, seq_length] + attention_mask = torch.tril(torch.ones((max_length, max_length))).unsqueeze(0) + attention_mask = attention_mask < 0.5 + return attention_mask + + def collate_fn(self, batch): + input_ids = [item['input_ids'][:-1] for item in batch] + labels = [item['input_ids'][1:] for item in batch] + contexts = [item['context_ids'] for item in batch] + context_lengths = torch.LongTensor([item['context_length'] for item in batch]) + loss_mask = [self._build_loss_mask(item)[1:] for item in batch] + + max_length = max([len(x) for x in input_ids]) + # increase max length to nearest multiple of 4 or 8 + if self.pad_to_max_length: + max_length = self.max_seq_length + else: + max_length = min(self.max_seq_length, self._round_to_nearest(max_length, 8)) + assert max_length <= self.max_seq_length + + attention_mask = [self._create_attention_mask(max_length) for _ in batch] + attention_mask = torch.stack(attention_mask) + position_ids = [list(range(max_length)) for _ in batch] + position_ids = torch.LongTensor(position_ids) + input_ids = torch.LongTensor( + self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id) + ) + labels = torch.LongTensor(self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id)) + loss_mask = torch.LongTensor(self._collate_item(loss_mask, max_length=max_length, pad_id=0)) + contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id)) + + processed_batch = { + 'tokens': input_ids, + 'labels': labels, + 'attention_mask': attention_mask, + 'loss_mask': loss_mask, + 'position_ids': position_ids, + 'contexts': contexts, + 'context_lengths': context_lengths, + } + + return processed_batch diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 2d69817954aa..19dc724c4043 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -706,6 +706,8 @@ def validation_epoch_end(self, outputs): self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) + return averaged_loss + def test_step(self, batch, batch_idx): return self.validation_step(batch, batch_idx) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py new file mode 100644 index 000000000000..37df8bbd4a57 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -0,0 +1,663 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
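# [Illustrative aside, not part of the patch] For a batch of B examples padded to length L,
# GPTSFTDataset.collate_fn above returns roughly the following shapes:
#   tokens, labels, loss_mask, position_ids, contexts : LongTensor [B, L]
#   attention_mask                                    : bool tensor [B, 1, L, L] (True = masked)
#   context_lengths                                   : LongTensor [B]
# where L is max_seq_length when pad_to_max_length=True, otherwise the batch maximum rounded up
# to a multiple of 8 and capped at max_seq_length.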
+ +import json + +import torch +from omegaconf import DictConfig, ListConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.common.metrics import MetricStringToTorchMetric +from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( + get_datasets_weights_and_num_samples, +) +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler +from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.text_generation_utils import LengthParam, SamplingParam, megatron_gpt_generate +from nemo.utils import AppState, logging + +try: + from apex.transformer import parallel_state + from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +__all__ = ['MegatronGPTSFTModel'] + + +class MegatronGPTSFTModel(MegatronGPTModel): + """ + Megatron GPT Supervised Fine-Tuning + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + if not HAVE_APEX: + raise ImportError( + "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + super().__init__(cfg, trainer=trainer) + self.sep_id = cfg.get('sep_id', 49704) + self.val_metric, self.val_metric_name = self.setup_metric(self.cfg.data.validation_ds) + self.val_metric = torch.nn.ModuleList(self.val_metric) if self.val_metric is not None else None + if hasattr(self.cfg.data, "test_ds"): + self.test_metric, self.test_metric_name = self.setup_metric(self.cfg.data.test_ds) + self.test_metric = torch.nn.ModuleList(self.test_metric) if self.test_metric is not None else None + + if self.cfg.get('megatron_amp_O2', False): + base_module = self.model.module + else: + base_module = self.model + + self.original_checkpointing_granularity = base_module.language_model.encoder.activations_checkpoint_granularity + self.original_checkpointing_num_layers = base_module.language_model.encoder.activations_checkpoint_num_layers + self.original_checkpointing_method = base_module.language_model.encoder.activations_checkpoint_method + + def setup_metric(self, data_cfg): + metric_name = "exact_string_match" + if not hasattr(data_cfg, "metric"): + metric = MetricStringToTorchMetric["exact_string_match"] + else: + if not hasattr(data_cfg.metric, "name"): + raise ValueError("Metric name is not provided in the metric config.") + if data_cfg.metric.name == "loss": + return None, "loss" + if data_cfg.metric.name not in MetricStringToTorchMetric: + raise KeyError( + f"{data_cfg.metric.name} is not supported. List of supported metrics: {MetricStringToTorchMetric.keys()}" + ) + if data_cfg.metric.name in self._metrics_require_string2category_map: + if data_cfg.metric.average is None: + raise ValueError( + f"{data_cfg.metric.name} requires specifying whether you want to compute a micro or macro average. Found None." + ) + if ( + data_cfg.metric.get('labels_are_strings', False) + and data_cfg.metric.name in self._metrics_require_string2category_map + ): + if data_cfg.metric.num_classes is None: + raise ValueError( + "Number of classes is not provided in the metric section within the data config. 
" + f"Please provide the number of classes in the data config to use the {data_cfg.metric.name} metric." + ) + if data_cfg.metric.get('class_labels', None) is None or not isinstance( + data_cfg.metric.get('class_labels', None), ListConfig + ): + raise ValueError( + "Class labels are not provided properly in the metric section witnin the data config. " + f"Please provide the class labels as a list of strings in the data config to use the {data_cfg.metric.name} metric." + ) + if len(data_cfg.metric.get('class_labels', None)) != data_cfg.metric.num_classes: + raise ValueError( + f"Number of class labels {len(data_cfg.metric.get('class_labels', None))} does not match `num_classes` : {data_cfg.metric.num_classes}" + ) + + metric_name = data_cfg.metric.name + metric = MetricStringToTorchMetric[metric_name] + + if isinstance(data_cfg.file_names, ListConfig): + if 'rouge' not in data_cfg.metric.name: + metric = [ + metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes) + for _ in range(len(data_cfg.file_names)) + ] + else: + metric = [metric() for _ in range(len(data_cfg.file_names))] + else: + if 'rouge' not in data_cfg.metric.name: + metric = [metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes)] + else: + metric = [metric()] + + return metric, metric_name + + @property + def _metrics_require_string2category_map(self): + return set(["f1", "accuracy", "average_precision"]) + + def setup(self, stage=None): + # NOTE: super().__init__ will try and setup train/val/test datasets, but we sidestep this using a if self._train_ds is not None condition + # We then set things up for real only once setup() of this class is called. + resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path + if resume_checkpoint_path: + init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) + else: + init_consumed_samples = 0 + self.init_consumed_samples = init_consumed_samples + + if stage == 'predict': + return + + # If the user wants to manually override train and validation dataloaders before calling `.fit()` + if self._train_dl is not None and self._validation_dl is not None: + return + self.build_train_valid_test_datasets(stage=stage) + if hasattr(self, '_train_ds'): + self.setup_training_dataloader() + if hasattr(self, '_validation_ds'): + self._validation_dl = self.setup_eval_dataloader(self._validation_ds, self.cfg.data.validation_ds) + if hasattr(self.cfg.data, 'test_ds'): + self._test_dl = self.setup_eval_dataloader(self._test_ds, self.cfg.data.test_ds) + + # when using pipeline model parallel the final stage need to initialize word embeddings + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + if isinstance(self.model, list): + for i, module in enumerate(self.model): + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + module.sync_initial_word_embeddings() + parallel_state.set_virtual_pipeline_model_parallel_rank(0) + else: + self.model.sync_initial_word_embeddings() + + if self.cfg.get('transformer_engine', False): + self.setup_transformer_engine_tp_groups() + + def _build_dataset(self, data_cfg, is_train=True): + datasets = [] + # Determine if we are using a single dataset or a list of datasets. 
+ is_list_config = isinstance(data_cfg.file_names, ListConfig) + if not is_list_config: + raise ValueError(f"SFT train/validation datasets must be provided as a list of individual JSONL files.") + + if is_train: + # Construct the data prefix list for `get_datasets_weights_and_num_samples()` + # that is of the format [weight1,file_name1,weight2,file_name2,...] + if data_cfg.concat_sampling_probabilities is None or not isinstance( + data_cfg.concat_sampling_probabilities, ListConfig + ): + raise ValueError( + ( + f"concat_sampling_probabilities must be a ListConfig with the same number of files in file_names." + f"Found: {data_cfg.concat_sampling_probabilities}" + ) + ) + + if len(data_cfg.get('concat_sampling_probabilities', None)) != len(data_cfg.file_names): + raise ValueError( + ( + f"concat_sampling_probabilities must be of the same size as file_names.", + f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(data_cfg.file_names)}", + ) + ) + + data_prefix = [] + for weight, prefix in zip(data_cfg.concat_sampling_probabilities, data_cfg.file_names): + data_prefix.append(weight) + data_prefix.append(prefix) + + if self.trainer.max_steps is None or self.trainer.max_steps <= 0: + raise ValueError( + f'Trainer max_steps must be set to a positive integer. Found {self.trainer.max_steps}' + ) + num_train_samples = [self.trainer.max_steps * data_cfg.global_batch_size] + _, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples(data_prefix, num_train_samples) + num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset]) + else: + num_train_samples_per_dataset = [[None]] * len(data_cfg.file_names) + + for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset): + dataset = GPTSFTDataset( + file_path=file_path, + tokenizer=self.tokenizer, + max_seq_length=data_cfg.max_seq_length, + min_seq_length=data_cfg.min_seq_length, + add_bos=data_cfg.get('add_bos', False), + add_eos=data_cfg.get('add_eos', True), + add_sep=data_cfg.get('add_sep', False), + sep_id=self.sep_id, + max_num_samples=num_samples[0], + seed=data_cfg.get('seed', 1234), + context_key=data_cfg.get('context_key', 'text'), + label_key=data_cfg.get('label_key', 'answer'), + separate_prompt_and_response_with_newline=data_cfg.get( + 'separate_prompt_and_response_with_newline', True + ), + answer_only_loss=self.cfg.get('answer_only_loss', True), + truncation_field=data_cfg.get('truncation_field', 'context'), + index_mapping_dir=data_cfg.get('index_mapping_dir', None), + prompt_template=data_cfg.get('prompt_template', None), + ) + datasets.append(dataset) + + if is_train: + dataset = BlendableDataset( + datasets=datasets, weights=data_cfg.concat_sampling_probabilities, size=num_train_samples_after_blend + ) + return dataset + else: + return datasets + + def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode): + # Function that determines whether to log based on the user provided name of the dataset or the dataloader index. + base_key = f"{mode}_{metric_name}_" if metric_name is not None else f"{mode}_" + # If the user provided names for each validation/test dataset, use those. + if hasattr(data_config, "names") and data_config.names is not None: + # With only a single validation/test dataset, the name is not a list. 
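# [Illustrative aside, not part of the patch] Example of the sampling arithmetic in _build_dataset
# above: with trainer.max_steps=1000, global_batch_size=128,
# file_names=[quarel.jsonl, trec.jsonl] and concat_sampling_probabilities=[0.3, 0.7],
# the blend targets 1000 * 128 = 128,000 training samples, of which roughly
# 0.3 * 128,000 = 38,400 come from quarel and 0.7 * 128,000 = 89,600 from trec
# (get_datasets_weights_and_num_samples adds a small oversampling margin on top of these).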
+ if not isinstance(data_config.names, ListConfig): + name = data_config.names + else: + name = data_config.names[dataloader_idx] + return base_key + name + else: + return base_key + f"dataloader{dataloader_idx}" + + def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + return self.inference_step(dataloader_iter, batch_idx, 'validation', dataloader_idx) + + def validation_epoch_end(self, outputs): + _ = self.inference_epoch_end(outputs, 'validation', self.cfg.data.validation_ds) + + def test_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + return self.inference_step(dataloader_iter, batch_idx, 'test', dataloader_idx) + + def test_epoch_end(self, outputs): + _ = self.inference_epoch_end(outputs, 'test', self.cfg.data.test_ds) + + def inference_step(self, dataloader_iter, batch_idx, mode, dataloader_idx=0): + # Call parent validation step to get the loss. + loss = super().validation_step(dataloader_iter, batch_idx) + return { + 'loss': loss, + 'preds': None, + 'labels': None, + 'inputs': None, + } + # TODO (sandeepsub): Figure out the subsequent decode bits. + length_params: LengthParam = { + "min_length": 0, + "max_length": batch['tokens'].size(1) - batch['context_lengths'].max(), + } + sampling_params: SamplingParam = { + "use_greedy": True, + "temperature": 1.0, + "top_k": 1, + "top_p": 0.94, + "repetition_penalty": 1.2, + "add_BOS": False, + "all_probs": False, + "compute_logprob": False, + } + result = megatron_gpt_generate( + model=self, + inputs=( + batch['tokens'].cuda(), + (batch['context_lengths'] - 1).cuda(), + ), # NOTE: We do -1 here to remove the space between context and response. + tokenizer=self.tokenizer, + sampling_params=sampling_params, + length_params=length_params, + check_sequence_parallel_and_checkpointing=False, # We need to skip these checks since we'll manually enbale and disable checkpointing between training and validation. + ) + + preds_text = [] + labels_text = [] + input_text = [] + for idx, item in enumerate(result['token_ids']): + pred = self.tokenizer.ids_to_text(item[batch['context_lengths'][idx] - 1 :]) + input = self.tokenizer.ids_to_text(item[: batch['context_lengths'][idx] - 1]) + label = self.tokenizer.ids_to_text(batch['tokens'][idx][batch['context_lengths'][idx] :].tolist()) + preds_text.append(pred.strip()) + labels_text.append(label.strip()) + input_text.append(input.strip()) + + metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx] + assert len(preds_text) == len(labels_text) == len(input_text) + for _, (pred, label) in enumerate(zip(preds_text, labels_text)): + # To compute metrics like pearson or spearman correlation, we need to cast the predicted string and labels to floats. + pred, label = self.cast_for_metric( + pred=pred.strip(), + label=label.strip(), + metric_name=self.val_metric_name if mode == 'validation' else self.test_metric_name, + class_labels=self.cfg.data.validation_ds.metric.get('class_labels', None) + if mode == 'validation' + else self.cfg.data.test_ds.metric.get('class_labels', None), + labels_are_strings=self.cfg.data.validation_ds.metric.get('labels_are_strings', False) + if mode == 'validation' + else self.cfg.data.test_ds.metric.get('labels_are_strings', False), + ) + _ = metric(pred, label) + + return { + 'loss': loss, + 'preds': preds_text, + 'labels': labels_text, + 'inputs': input_text, + } + + def inference_epoch_end(self, outputs, mode, data_cfg): + # Parent class will handle logging of the loss. 
+ if not outputs: + return + + if isinstance(outputs[0], dict): + outputs = [outputs] + + averaged_loss = [] + averaged_metric = [] + metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name + # Log metrics for each provided validation/test dataset. + for dataloader_idx, output in enumerate(outputs): + loss = super().validation_epoch_end([x['loss'] for x in output]) + # Determine the key used to log the loss based on the user provided name of the dataset or the dataloader index. + loss_log_key = self._determine_log_key(data_cfg, dataloader_idx, "loss", mode) + self.log(loss_log_key, loss) + averaged_loss.append(loss) + + # Skip the rest of this loop if the user wants to monitor the loss only. + if self.val_metric is None: + continue + # Determine the key used to log the eval metric based on the user provided name of the dataset or the dataloader index. + metric_log_key = self._determine_log_key(data_cfg, dataloader_idx, metric_name, mode) + metric_object = ( + self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx] + ) + metric = metric_object.compute() + # Handle logging of GLUE/XNLI separately here. XNLI has a separate metric per language. + if isinstance(metric, dict): + if metric_name == 'rouge': + metric = metric['rougeL_fmeasure'] + else: + metric = metric['acc'] + torch.distributed.all_reduce( + metric, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_data_parallel_group() + ) + metric = metric / parallel_state.get_data_parallel_world_size() + self.log(metric_log_key, metric) + logging.info(f"{mode} {metric_name}: {metric}") + + metric_object.reset() + + averaged_metric.append(metric) + + # Write predictions, labels, and inputs to a file for each validation/test dataset. + if data_cfg.get("write_predictions_to_file", False): + + # Check if the user provided a prefix path to the file(s) they want to write. + if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None: + raise ValueError( + f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file." + ) + + # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks. + gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + gathered_outputs, + [{'preds': x['preds'], 'labels': x['labels'], 'inputs': x['inputs'],} for x in output], + group=parallel_state.get_data_parallel_group(), + ) + + # Figure out what the suffix of the file should be. + filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode) + + # Keep a set of ground truths and inputs to write deduplicated predictions. Distributed Sampler may duplicate examples. + gt_inp_set = set() + deduplicated_outputs = { + 'preds': [], + 'labels': [], + 'inputs': [], + } + + # PTL models have a self.global_rank attribute and we want to write to disk only on global rank 0. 
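# [Illustrative aside, not part of the patch] Example of the keys produced by the
# _determine_log_key calls above: with names=[quarel, trec] and metric name "loss",
# per-dataset validation losses are logged as "validation_loss_quarel" and "validation_loss_trec";
# without names they fall back to "validation_loss_dataloader0" and "validation_loss_dataloader1".
# The averaged value is logged separately as "validation_loss", which matches the checkpoint
# monitor in the example YAML (validation_${model.data.validation_ds.metric.name}).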
+ if self.global_rank == 0: + for rank in range(0, parallel_state.get_data_parallel_world_size()): + for batch in gathered_outputs[rank]: + for pred, label, input in zip(batch['preds'], batch['labels'], batch['inputs']): + gt_inp_set.add(input + label) + deduplicated_outputs['preds'].append(pred) + deduplicated_outputs['labels'].append(label) + deduplicated_outputs['inputs'].append(input) + self.write_predictions_to_file( + deduplicated_outputs, f"{data_cfg.output_file_path_prefix}_{filename_log_key}" + ) + torch.distributed.barrier() + + # Logging of the averaged metrics: + averaged_loss = sum(averaged_loss) / len(averaged_loss) + averaged_metric = sum(averaged_metric) / len(averaged_metric) if len(averaged_metric) > 1 else None + + # Handle case where metrics can be nan or inf. This can break checkpoint save/load. + if averaged_metric is not None and (torch.isinf(averaged_metric) or torch.isnan(averaged_metric)): + app_state = AppState() + monitor_mode = app_state.checkpoint_callback_params.mode + assert monitor_mode in ['min', 'max'] + averaged_metric = 0.0 if monitor_mode == 'max' else 1e5 + + if mode == 'validation': + self.log("validation_loss", averaged_loss) + if averaged_metric is not None: + self.log(f"validation_{self.val_metric_name}", averaged_metric) + elif mode == 'test': + self.log("test_loss", averaged_loss) + if averaged_metric is not None: + self.log(f"test_{self.test_metric_name}", averaged_metric) + + return averaged_loss, averaged_metric + + def write_predictions_to_file(self, outputs, output_file_path_prefix): + with open(output_file_path_prefix + "_inputs_preds_labels.jsonl", "w") as f_json: + assert len(outputs['inputs']) == len(outputs['preds']) == len(outputs['labels']) + for i, p, l in zip(outputs['inputs'], outputs['preds'], outputs['labels']): + f_json.write(json.dumps({'input': i, 'pred': p, 'label': l}) + '\n') + + def cast_for_metric(self, pred, label, metric_name, class_labels=None, labels_are_strings=False): + if metric_name == 'exact_string_match' or 'rouge' in metric_name: + return pred, label + pred = pred.replace(' ', '') + label = label.replace(' ', '') + + # Correlation metrics require casting to float. + if metric_name in ['pearson_corr_coef', 'spearman_corr_coef']: + # Text-to-text model predictions may not always be valid floating point numbers. + try: + pred = float(pred) + except ValueError: + pred = 0.0 + + try: + label = float(label) + except ValueError: + raise ValueError(f'Could not convert {label} to float.') + + pred = torch.FloatTensor([pred]).to(self.device) + label = torch.FloatTensor([label]).to(self.device) + + # Other metrics require casting to integers. + elif metric_name in self._metrics_require_string2category_map and not labels_are_strings: + # Text-to-text model predictions may not always be valid integers. + try: + pred = int(pred) + except ValueError: + pred = 0 + + try: + label = int(label) + except ValueError: + raise ValueError(f'Could not convert {label} to int.') + + pred = torch.LongTensor([pred]).to(self.device) + label = torch.LongTensor([label]).to(self.device) + + # If labels are strings, we need to convert them to indices for some metrics. + elif metric_name in self._metrics_require_string2category_map and labels_are_strings: + # Cast string labels to integers before computing the metric. + if pred not in class_labels: + pred = 0 # If the prediction is not in the class labels, use the first class label. 
+ else: + pred = class_labels.index(pred) + if label not in class_labels: + raise ValueError(f"Ground truth labe; {label} is not in the class labels list : {class_labels}") + label = class_labels.index(label) + pred = torch.LongTensor([pred]).to(self.device) + label = torch.LongTensor([label]).to(self.device) + else: + raise ValueError(f'Metric {metric_name} not supported.') + + return pred, label + + # Override the parent batch reconfiguring logic. + def _reconfigure_and_process_inference_batch(self, batch): + global_batch_per_gpu = batch['tokens'].size(0) + # This should happen only on the last batch of the validation/test dataset with drop_last=False. + if global_batch_per_gpu != self.cfg.data.validation_ds.global_batch_size: + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_per_gpu, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + return batch + + def build_train_valid_test_datasets(self, stage): + if stage != 'test': + logging.info('Building GPT SFT validation datasets.') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._validation_ds = self._build_dataset(self.cfg.data.validation_ds, is_train=False) + logging.info(f'Length of val dataset: {len(self._validation_ds[0])}') + + if stage != 'validate': + if hasattr(self.cfg.data, 'test_ds'): + logging.info('Building GPT SFT test datasets.') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._test_ds = self._build_dataset(self.cfg.data.test_ds, is_train=False) + logging.info(f'Length of test dataset: {len(self._test_ds[0])}') + + if stage == 'validate' or stage == 'test': + return + logging.info('Building GPT SFT traing datasets.') + self._train_ds = self._build_dataset(self.cfg.data.train_ds) + logging.info(f'Length of train dataset: {len(self._train_ds)}') + + def build_data_loader(self, dataset, data_cfg, consumed_samples=0): + """Buld dataloader given an input dataset.""" + + logging.info(f'Building dataloader with consumed samples: {consumed_samples}') + if isinstance(dataset, BlendableDataset): + collate_fn = dataset.datasets[0].collate_fn + else: + collate_fn = dataset.collate_fn + + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=data_cfg.micro_batch_size, + global_batch_size=data_cfg.global_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=True, + pad_samples_to_global_batch_size=False, + ) + return torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=data_cfg.num_workers, + pin_memory=data_cfg.pin_memory, + ) + + def setup_training_dataloader(self): + if hasattr(self, '_train_ds'): + consumed_samples = self.compute_consumed_samples(0) + self._train_dl = self.build_data_loader( + dataset=self._train_ds, data_cfg=self.cfg.data.train_ds, consumed_samples=consumed_samples, + ) + + def setup_eval_dataloader(self, datasets, data_cfg): + dataloaders = [] + for dataset in datasets: + eval_dl = self.build_data_loader(dataset=dataset, data_cfg=data_cfg, consumed_samples=0,) + dataloaders.append(eval_dl) + return dataloaders + + def _reset_activation_checkpointing_args(self): + if 
self.cfg.get('megatron_amp_O2', False): + base_module = self.model.module + else: + base_module = self.model + + base_module.language_model.encoder.activations_checkpoint_granularity = None + base_module.language_model.encoder.activations_checkpoint_method = None + base_module.language_model.encoder.activations_checkpoint_num_layers = None + + def _restore_activation_checkpointing_args(self): + if self.cfg.get('megatron_amp_O2', False): + base_module = self.model.module + else: + base_module = self.model + base_module.language_model.encoder.activations_checkpoint_granularity = self.original_checkpointing_granularity + base_module.language_model.encoder.activations_checkpoint_method = self.original_checkpointing_method + base_module.language_model.encoder.activations_checkpoint_num_layers = self.original_checkpointing_num_layers + + def on_validation_epoch_start(self): + self._reset_activation_checkpointing_args() + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.validation_ds.global_batch_size, + micro_batch_size=self.cfg.data.validation_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + return super().on_validation_epoch_start() + + def on_test_epoch_start(self): + app_state = AppState() + self._reset_activation_checkpointing_args() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.test_ds.global_batch_size, + micro_batch_size=self.cfg.data.test_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + return super().on_test_epoch_start() + + def on_test_epoch_end(self): + self.on_inference_epoch_end(self.cfg.data.test_ds) + return super().on_test_epoch_end() + + def on_validation_epoch_end(self): + self.on_inference_epoch_end(self.cfg.data.validation_ds) + return super().on_validation_epoch_end() + + def on_inference_epoch_end(self, ds): + app_state = AppState() + self._restore_activation_checkpointing_args() + if hasattr(self, "_train_ds"): + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.train_ds.global_batch_size, + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + # When running `trainer.validate()`, the training dataset is not available. + else: + logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.') + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=ds.global_batch_size, + micro_batch_size=ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def on_train_epoch_start(self) -> None: + # Same logic as validation epoch end, but this may be needed if there is no validation sanity check to trigger validation_epoch_end() + self.on_validation_epoch_end() + return super().on_train_epoch_start() diff --git a/scripts/nlp_language_modeling/niv2/preprocess_niv2.py b/scripts/nlp_language_modeling/niv2/preprocess_niv2.py new file mode 100644 index 000000000000..073d6da8f32c --- /dev/null +++ b/scripts/nlp_language_modeling/niv2/preprocess_niv2.py @@ -0,0 +1,171 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +from argparse import ArgumentParser +from multiprocessing import Pool + +from sacremoses import MosesDetokenizer + +from nemo.collections.common.tokenizers import AutoTokenizer + + +""" +This script converts the NaturalInstructions v2 dataset into individual JSONL files. + +Use instructions: + +1. Download the NaturalInstructions dataset by cloning it from allenai: + git clone https://github.com/allenai/natural-instructions. The raw data should be in the tasks folder. + +2. Run this script: + python preprocess_niv2.py \ + --niv2_dataset_path natural-instructions/tasks \ + --jsonl_output_path natural-instructions/train_tasks_default_jsonl \ + --splits_file_path natural-instructions/splits/default/train_tasks.txt + +3. The output will be in the jsonl_output_path directory. + +4. Each JSONL file is compatible with NeMo's T0JSONLMemMapDataset (https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/data/language_modeling/t0_dataset.py) +""" + + +def remove_newline_and_detokenize(x, detokenizer): + x = re.sub(r'\\n+', ' ', x) + x = re.sub(r'\n+', ' ', x) + x = re.sub(r'\\r+', ' ', x) + x = re.sub(r'\r+', ' ', x) + x = x.strip() + x = detokenizer.detokenize([x]) + return x + + +def detokenize(x, detokenizer): + x = x.strip() + # NOTE: Commenting this out since sacremoses seems to remove \n as part of detokenization. + # x = detokenizer.detokenize([x]) + return x + + +def is_empty(x, tokenizer): + return len(tokenizer.text_to_tokens(x.strip())) < 1 + + +def write_dataset_to_file(file_name, output_file_name, detokenizer, tokenizer, idx, total_num_files, remove_newline): + print(f'Processing file {idx + 1}/{total_num_files} : {file_name} -> {output_file_name}') + dataset = json.load(open(file_name, 'r')) + with open(output_file_name, 'w') as f: + instances = dataset['Instances'] + definitions = dataset['Definition'] + for definition in definitions: + if is_empty(definition, tokenizer): + continue + for instance in instances: + id = instance['id'] + input = instance['input'] + outputs = instance['output'] + # On rare occasions, the same instance can have multiple outputs. We add all of them as examples. 
+ if is_empty(input, tokenizer): + continue + for output in outputs: + if is_empty(output, tokenizer): + continue + if remove_newline: + prompted_input = definition + ' ' + input + else: + prompted_input = definition + '\n\n' + input + proc_func = remove_newline_and_detokenize if remove_newline else detokenize + prompted_input = proc_func(prompted_input, detokenizer) + output = proc_func(output, detokenizer) + instance_object = { + 'id': id, + 'input': prompted_input, + 'output': output, + } + f.write(json.dumps(instance_object) + '\n') + + +def process_folder(data_folder, output_folder, splits_file, remove_newline): + detokenizer = MosesDetokenizer('en') + tokenizer = AutoTokenizer("gpt2") + assert os.path.isdir(data_folder) + assert os.path.exists(splits_file) + if not os.path.exists(output_folder): + os.system(f'mkdir -p {output_folder}') + if not os.path.exists(os.path.join(output_folder, 'train')): + os.system(f'mkdir -p {os.path.join(output_folder, "train")}') + if not os.path.exists(os.path.join(output_folder, 'test')): + os.system(f'mkdir -p {os.path.join(output_folder, "test")}') + + splits_file_names = [line.strip() + '.json' for line in open(splits_file, 'r')] + print(f'Found {len(os.listdir(data_folder))} files in the data folder ...') + print(f'Found {len(splits_file_names)} files in the splits file ...') + print(f'Processing {len(splits_file_names)}/{len(os.listdir(data_folder))} files ...') + pool_args = [] + for idx, file_name in enumerate(splits_file_names): + print(f'Processing file {idx}/{len(splits_file_names)}: {file_name}') + if not os.path.exists(os.path.join(data_folder, file_name)): + raise FileNotFoundError(f'Could not find {os.path.join(data_folder, file_name)}') + if not file_name.endswith('.json'): + print(f'Skipping {file_name} because it is not a JSON file') + continue + output_file_name = os.path.join(output_folder, file_name.replace('.json', '.jsonl')) + pool_args.append( + ( + os.path.join(data_folder, file_name), + output_file_name, + detokenizer, + tokenizer, + idx, + len(splits_file_names), + remove_newline, + ) + ) + pool = Pool(42) + pool.starmap(write_dataset_to_file, pool_args) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument( + "--niv2_dataset_path", + type=str, + required=True, + help="Path to the raw NaturalInstructions v2 data. Should be the folder that contains one JSON file per task. After cloning the repo this should correspond to natural-instructions/tasks", + ) + parser.add_argument( + "--jsonl_output_path", + type=str, + required=True, + help="Path to output folder where JSONL files will be written.", + ) + parser.add_argument( + "--splits_file_path", type=str, default="default", help="Path to the file that contains splits. 
ex: ", + ) + parser.add_argument( + "--remove_newline", action="store_true", help="Whether to remove newlines from the input and output.", + ) + args = parser.parse_args() + process_folder(args.niv2_dataset_path, args.jsonl_output_path, args.splits_file_path, args.remove_newline) diff --git a/scripts/nlp_language_modeling/t0/merge_train_tasks.py b/scripts/nlp_language_modeling/t0/merge_train_tasks.py index e28c6ac6150a..10bad27db002 100644 --- a/scripts/nlp_language_modeling/t0/merge_train_tasks.py +++ b/scripts/nlp_language_modeling/t0/merge_train_tasks.py @@ -96,6 +96,12 @@ def merge_train_folder(train_data_folder, merged_train_data_folder): for line in f: line = json.loads(line) line['task_name_with_prompt'] = fname + if line['input'].strip() == '': + print(f'WARNING: Empty input for {fname}') + continue + if line['output'].strip() == '': + print(f'WARNING: Empty output for {fname}') + continue fptrs[task].write(json.dumps(line) + '\n') if not found: print(f'WARNING: Could not find task for {fname}') diff --git a/scripts/nlp_language_modeling/t0/t0_dataset_preproc.py b/scripts/nlp_language_modeling/t0/t0_dataset_preproc.py index 2fec90bf261d..618c02c0cc13 100644 --- a/scripts/nlp_language_modeling/t0/t0_dataset_preproc.py +++ b/scripts/nlp_language_modeling/t0/t0_dataset_preproc.py @@ -58,17 +58,20 @@ def _feature_config(shape, dtype): return tf.io.FixedLenFeature(shape, dtype) -def remove_newline_and_detokenize(x, detokenizer): - x = re.sub(r'\\n+', ' ', x) - x = re.sub(r'\n+', ' ', x) - x = re.sub(r'\\r+', ' ', x) - x = re.sub(r'\r+', ' ', x) +def remove_newline_and_detokenize(x, detokenizer, remove_newlines): + if remove_newlines: + x = re.sub(r'\\n+', ' ', x) + x = re.sub(r'\n+', ' ', x) + x = re.sub(r'\\r+', ' ', x) + x = re.sub(r'\r+', ' ', x) x = x.strip() - x = detokenizer.detokenize([x]) + # NOTE: Moving the detokenizer inside this condition since sacremoses detokenize seems to remove \n as well. + if remove_newlines: + x = detokenizer.detokenize([x]) return x -def write_dataset_to_file(dataset, filename, detokenizer): +def write_dataset_to_file(dataset, filename, detokenizer, remove_newlines): with open(filename, 'w') as f: for item in dataset: # NOTE: Although we do `.tolist()` here this is not actually a list. This is just to convert from a numpy to python object so we can check if it is True/False. 
@@ -77,20 +80,24 @@ def write_dataset_to_file(dataset, filename, detokenizer): continue item_object = {} - i = remove_newline_and_detokenize(item['inputs_pretokenized'].numpy().decode('utf-8'), detokenizer) + i = remove_newline_and_detokenize( + item['inputs_pretokenized'].numpy().decode('utf-8'), detokenizer, remove_newlines + ) item_object['input'] = i - t = remove_newline_and_detokenize(item['targets_pretokenized'].numpy().decode('utf-8'), detokenizer) + t = remove_newline_and_detokenize( + item['targets_pretokenized'].numpy().decode('utf-8'), detokenizer, remove_newlines + ) item_object['output'] = t if 'answer_choices' in item: choices = [ - remove_newline_and_detokenize(x.decode('utf-8'), detokenizer) + remove_newline_and_detokenize(x.decode('utf-8'), detokenizer, remove_newlines) for x in item['answer_choices'].numpy().tolist() ] item_object['choices'] = choices f.write(json.dumps(item_object) + '\n') -def write_train_val_test_dataset_to_file(file_name, folder_name, output_folder, detokenizer, split): +def write_train_val_test_dataset_to_file(file_name, folder_name, output_folder, detokenizer, split, remove_newlines): ds = tf.data.TFRecordDataset(tf.io.gfile.glob([file_name])) fdict = _TASK_SPLITS_AND_FEATURES_DICT[folder_name]['features_dict'] feature_description = {feat: _feature_config(**desc) for feat, desc in fdict.items()} @@ -102,10 +109,10 @@ def write_train_val_test_dataset_to_file(file_name, folder_name, output_folder, lambda x: {k: tf.cast(v, fdict[k]["dtype"]) for k, v in x.items()}, num_parallel_calls=tf.data.experimental.AUTOTUNE, ) - write_dataset_to_file(ds, os.path.join(output_folder, split, folder_name + '.jsonl'), detokenizer) + write_dataset_to_file(ds, os.path.join(output_folder, split, folder_name + '.jsonl'), detokenizer, remove_newlines) -def process_folder(data_folder, folder_name, output_folder, detokenizer): +def process_folder(data_folder, folder_name, output_folder, detokenizer, remove_newlines): if not os.path.isdir(os.path.join(data_folder, folder_name)): return print(f'Processing {folder_name}') @@ -115,14 +122,20 @@ def process_folder(data_folder, folder_name, output_folder, detokenizer): if not os.path.exists(train_fname): print(f'Could not find {train_fname}') return - write_train_val_test_dataset_to_file(train_fname, folder_name, output_folder, detokenizer, 'train') + write_train_val_test_dataset_to_file( + train_fname, folder_name, output_folder, detokenizer, 'train', remove_newlines + ) if os.path.exists(valid_fname): - write_train_val_test_dataset_to_file(valid_fname, folder_name, output_folder, detokenizer, 'val') + write_train_val_test_dataset_to_file( + valid_fname, folder_name, output_folder, detokenizer, 'val', remove_newlines + ) if os.path.exists(test_fname): - write_train_val_test_dataset_to_file(test_fname, folder_name, output_folder, detokenizer, 'test') + write_train_val_test_dataset_to_file( + test_fname, folder_name, output_folder, detokenizer, 'test', remove_newlines + ) -def process_all_folders(data_folder, output_folder): +def process_all_folders(data_folder, output_folder, remove_newlines): detokenizer = MosesDetokenizer('en') assert os.path.isdir(data_folder) if not os.path.exists(output_folder): @@ -137,7 +150,7 @@ def process_all_folders(data_folder, output_folder): print(f'Found {len(os.listdir(data_folder))} folders to process ...') pool_args = [] for folder_name in os.listdir(data_folder): - pool_args.append((data_folder, folder_name, output_folder, detokenizer)) + pool_args.append((data_folder, folder_name, 
output_folder, detokenizer, remove_newlines)) pool = Pool() pool.starmap(process_folder, pool_args) @@ -156,5 +169,8 @@ def process_all_folders(data_folder, output_folder): required=True, help="Path to output folder where JSONL files will be written.", ) + parser.add_argument( + "--remove_newlines", action="store_true", help="Whether to remove newlines from the input and output.", + ) args = parser.parse_args() - process_all_folders(args.p3_dataset_path, args.jsonl_output_path) + process_all_folders(args.p3_dataset_path, args.jsonl_output_path, args.remove_newlines)
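Since the batch-size juggling in _reconfigure_and_process_inference_batch and the epoch hooks is the subtlest part of the model change above, here is a minimal sketch (illustrative names only, not NeMo API) of the arithmetic applied when a trailing validation/test batch with drop_last=False is smaller than the configured global batch size:

    def trailing_batch_sizes(samples_on_this_rank: int, configured_global_batch: int, dp_world_size: int):
        """Restate the values handed to _reconfigure_microbatch_calculator for a short final eval batch.

        Returns (global_batch_size, micro_batch_size), or None when the configured sizes still apply.
        """
        if samples_on_this_rank == configured_global_batch:
            return None  # full-sized batch: keep the calculator as configured for the eval dataset
        # Short trailing batch: treat whatever landed on each rank as a single micro batch.
        return samples_on_this_rank * dp_world_size, samples_on_this_rank


    # Example: validation_ds.global_batch_size=4, data-parallel world size 2,
    # and only 3 samples reach each rank on the final step.
    print(trailing_batch_sizes(3, 4, 2))  # -> (6, 3)

After the eval loop, on_inference_epoch_end() switches the calculator back to the train batch sizes, or to the eval dataset's own sizes when no training set is loaded, matching the warning branch in that hook.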