PEFT GPT & T5 Refactor #7308

Merged · 167 commits · Sep 28, 2023
Changes from 137 commits

Commits
66b0607
initial implementation of add_adapters API
cuichenx Aug 16, 2023
82fcfe8
correct type hint
cuichenx Aug 16, 2023
d8a13a3
Add config in add_adapters for save and load (@author bobchen)
cuichenx Aug 17, 2023
4b02f0d
Remove AdapterConfig to avoid import error
meatybobby Aug 18, 2023
06b46a1
Add AdaterConfig back and move adaptermixin to sft model
meatybobby Aug 19, 2023
63f4c74
Add NLPSaveRestoreConnector as default in NLPModel.restore_from
meatybobby Aug 22, 2023
08dbcee
Add restore_from_nemo_with_adapter and test script
meatybobby Aug 22, 2023
48f9bcc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2023
6a894fa
rename t5 file and classes to be consistent with GPT
cuichenx Aug 21, 2023
b44052b
add t5 sft dataset
cuichenx Aug 21, 2023
cfcd00d
add support for single-file format with T5SFTDataset
cuichenx Aug 21, 2023
fefcfe6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2023
5c2ab84
Various small changes to make T5 SFT work like GPT SFT
cuichenx Aug 24, 2023
b2eed6b
Merge remote-tracking branch 'origin/peft_refactor' into peft_refactor
cuichenx Aug 24, 2023
916729a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2023
a72b4ab
Add adapter evaluation test script
meatybobby Aug 24, 2023
4e0258e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2023
4eee5b6
Add MultiAdaterConfig for ia3 and fix builder issue
meatybobby Aug 25, 2023
0e72328
Make ptuning for T5SFTModel work using mixin
cuichenx Aug 25, 2023
79be201
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 25, 2023
2be50cd
Add IA3_Adapter for AdapterName
meatybobby Aug 25, 2023
0481d27
Add adapter name for ptuning and attention adapter
meatybobby Aug 28, 2023
cfd1105
Make test script GPT/T5 agnostic
cuichenx Aug 28, 2023
3f7d1ab
Add layer selection feature
cuichenx Aug 28, 2023
fd27447
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2023
5caf328
Merge branch 'main' into peft_refactor
meatybobby Aug 28, 2023
3a4a22d
Integrate adapter name and config
meatybobby Aug 28, 2023
5011227
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2023
219a5bb
update gpt peft tuning script to new API
cuichenx Aug 28, 2023
0ef1cab
add t5 peft tuning script with new API
cuichenx Aug 28, 2023
52b5dbb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2023
4ff60bc
Fix IA3 layer selection issue
meatybobby Aug 29, 2023
43a0da5
Override state_dict on SFT model instead of mixin
meatybobby Aug 29, 2023
fd178d2
Add load adapter by adapter config
meatybobby Aug 29, 2023
ee69afe
move peft config map away from example script
cuichenx Aug 29, 2023
0e34c3f
auto get config from nemo adapter
meatybobby Aug 29, 2023
fb8fc50
Move PEFTConfig to new file
meatybobby Aug 30, 2023
e671279
fix ckpt save/load for t5
cuichenx Aug 30, 2023
a0be911
name change: add_adapters -> add_adapter
cuichenx Aug 30, 2023
b398566
variable name change
cuichenx Aug 30, 2023
33d37a1
update t5 script
cuichenx Aug 30, 2023
7a06a46
Merge branch 'main' into peft_refactor
meatybobby Aug 30, 2023
1798554
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2023
922f887
fix t5 issues
cuichenx Aug 31, 2023
20ef4d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2023
2277552
Add weight tying
meatybobby Aug 31, 2023
b433c1d
update gpt tuning script
cuichenx Aug 31, 2023
ce0c5ba
PEFT-API proposal
marcromeyn Aug 30, 2023
37da440
Fix according to comments
marcromeyn Aug 31, 2023
de3c0b2
update tuning scripts
cuichenx Aug 31, 2023
811dd6d
move merge_cfg_with to mixin class since it applies to both gpt and t…
cuichenx Aug 31, 2023
0c50a57
Merge branch 'main' into peft_refactor
meatybobby Aug 31, 2023
50c4b67
Add mcore_gpt support for NLPAdapterMixin
meatybobby Aug 31, 2023
eb70660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2023
f7fc5af
fix typo
cuichenx Aug 31, 2023
c25296d
variable name change to distinguish "peft" and "adapter"
cuichenx Aug 31, 2023
b2ea917
override `load_adapters` to support `add_adapter` name change
cuichenx Aug 31, 2023
125dc88
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2023
d1c5381
update tuning and eval script for adapter save/load
cuichenx Sep 1, 2023
21cc7c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 1, 2023
9ed5a17
Add Ptuning on first stage only
meatybobby Sep 1, 2023
c85daa5
add lora tutorial for review
cuichenx Sep 1, 2023
ff5dc2b
Fix layer selection for mcore
meatybobby Sep 1, 2023
5d90fb2
add landing page
cuichenx Sep 1, 2023
1fb3d8b
Merge remote-tracking branch 'origin/peft_refactor' into peft_refactor
cuichenx Sep 1, 2023
38788ef
fix resume training
blahBlahhhJ Sep 1, 2023
48ddfd7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 1, 2023
ead563c
add mcore condition in sharded_state_dict to make sft work
cuichenx Sep 1, 2023
b8f4676
Update lora_tutorial.md
hkelly33 Sep 1, 2023
a5be21b
rename Adapter to AttentionAdapter to avoid confusion in doc
cuichenx Sep 2, 2023
e408cb7
Change load_adapters to load .nemo
meatybobby Sep 2, 2023
4682c92
Merge remote-tracking branch 'origin/peft_refactor' into peft_refactor
cuichenx Sep 2, 2023
e121591
add quick start guide
cuichenx Sep 2, 2023
c928a34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 2, 2023
9f7fb91
Add load_adapters with .ckpt
meatybobby Sep 2, 2023
3c0e2ed
Remove setup_complete changes in load_adapters
meatybobby Sep 2, 2023
a63b188
update landing page
cuichenx Sep 2, 2023
c2cb621
remove typo
cuichenx Sep 2, 2023
ac625fd
Updated quick_start.md per Chen Cui
hkelly33 Sep 5, 2023
4cfbb8f
Add inference config merger and tutorial
meatybobby Sep 5, 2023
45346ca
Add doc string for NLPAdapterModelMixin and deprecated warning on Meg…
meatybobby Sep 6, 2023
ca96a5b
add supported_methods.md and update other documentations
cuichenx Sep 6, 2023
4ca143f
Update supported_methods.md
arendu Sep 6, 2023
6104da7
Update landing_page.md
arendu Sep 6, 2023
eb8365f
Modify doc string for NLPAdapterModelMixin
meatybobby Sep 6, 2023
c533832
Add doc string add_adapters in NLPAdapterModelMixin
meatybobby Sep 6, 2023
d1c753d
rename canonical adapters
cuichenx Sep 6, 2023
cd78076
remove mcore hard dependency
cuichenx Sep 7, 2023
0d3f61e
[PATCH] move microbatch calculator to nemo from apex
Sep 7, 2023
0236e3c
remove apex dependency in gpt and t5 sft models
cuichenx Sep 7, 2023
0cac299
remove apex dependency in gpt model
cuichenx Sep 7, 2023
1c09a5a
render doc strings
cuichenx Sep 8, 2023
b03fe5d
fix
cuichenx Sep 8, 2023
263d465
Add missing virtual_tokens on ptuning
meatybobby Sep 8, 2023
84c765e
fix docstrings
cuichenx Sep 8, 2023
e375eaa
update gpt-style model coverage in docs
cuichenx Sep 8, 2023
3717132
update docstring
cuichenx Sep 8, 2023
57136ed
Remove pdb
meatybobby Sep 8, 2023
ba20af0
add lightning_fabric to make docstring rendering work
cuichenx Sep 8, 2023
8ccaaaf
Add Ptuning missing key
meatybobby Sep 8, 2023
7e2ef24
try docstring rendering
cuichenx Sep 9, 2023
0b66e48
Fix ptuning issue
meatybobby Sep 9, 2023
dc267bd
update gpt t5 peft tuning and eval scripts
cuichenx Sep 11, 2023
303e9c9
typos
cuichenx Sep 11, 2023
9a9e417
update eval config
cuichenx Sep 11, 2023
490bdc9
fix bug relating to apex dependency removal
cuichenx Sep 12, 2023
af212dd
typo
cuichenx Sep 12, 2023
0ce6a63
make predict step behave the same as test step
cuichenx Sep 12, 2023
6f493b4
make lora tutorial work in notebook
cuichenx Sep 12, 2023
49a8fa7
cosmetics
cuichenx Sep 12, 2023
ca63fc5
update yaml scripts
cuichenx Sep 12, 2023
a1d577a
Merge remote-tracking branch 'origin/peft_refactor' into peft_refactor
cuichenx Sep 12, 2023
8340695
mcore_gpt attribute optional
cuichenx Sep 12, 2023
d470fed
typo
cuichenx Sep 12, 2023
952ca53
update eval scripts and fix T5 eval bugs
cuichenx Sep 12, 2023
fb4e3d9
add NLPDDPStrategyNotebook and trainer builder logic to use it
cuichenx Sep 13, 2023
5440847
update lora notebook to use new trainer builder
cuichenx Sep 13, 2023
c2cc936
fix microbatch calculator bug for inference after training
cuichenx Sep 13, 2023
d134f03
Convert markdown files to RST and incorporate with doc
cuichenx Sep 13, 2023
08755a0
typo
cuichenx Sep 18, 2023
4f414db
revise language
cuichenx Sep 18, 2023
1276078
remove extra cell
cuichenx Sep 18, 2023
64ecf05
remove unnecessary inheritance
cuichenx Sep 18, 2023
833ad14
remove old tests
cuichenx Sep 18, 2023
bf4d771
move layer selection default so logging messages make sense
cuichenx Sep 18, 2023
3f13f53
remove `save_adapters` as adapter weights are saved automatically dur…
cuichenx Sep 19, 2023
45d9f4b
initialize weights from a checkpoint instead of randomly
cuichenx Sep 19, 2023
8bc32ca
multiple fields can form a context (#7147)
arendu Sep 1, 2023
7d2458a
revert config changes
cuichenx Sep 21, 2023
7c16e7d
remove accidental breakpoint
cuichenx Sep 22, 2023
29d9197
support TP>1 loading
cuichenx Sep 22, 2023
cf82f6d
infer adapter type from checkpoint in during eval
cuichenx Sep 22, 2023
068d68b
breakup add adapter
cuichenx Sep 22, 2023
90bc80b
enable interpolation of train_ds and validation_ds
cuichenx Sep 22, 2023
d5017e6
update metric calc script to conform to single-file eval format
cuichenx Sep 22, 2023
05b89e2
remove extraneous print
cuichenx Sep 22, 2023
a44a34b
update lora notebook for updated merge_inference_cfg
cuichenx Sep 23, 2023
4eac726
Update nlp_adapter_mixins.py
cuichenx Sep 25, 2023
b7752ef
turn off grad scaler for PP to match old scripts
cuichenx Sep 25, 2023
745a57e
remove PEFTSaveRestoreConnector since functionality all covered by th…
cuichenx Sep 25, 2023
553647a
Merge remote-tracking branch 'origin/peft_refactor' into peft_refactor
cuichenx Sep 25, 2023
a9453de
remove resume_from_checkpoint check since covered in #7335
cuichenx Sep 25, 2023
03a6be2
revert changes made in eval config interpolation
cuichenx Sep 25, 2023
ffde138
more interpolation
cuichenx Sep 25, 2023
ceb4b2d
typo
cuichenx Sep 25, 2023
a2dd4a0
Merge branch 'main' into peft_refactor
cuichenx Sep 25, 2023
09cde67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2023
50b5848
remove dup line
cuichenx Sep 25, 2023
ad56c7c
code style warnings
cuichenx Sep 25, 2023
af67d40
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2023
83f340e
fix config mistake
cuichenx Sep 25, 2023
ebf94f3
add copyright header
cuichenx Sep 26, 2023
33b3d92
fix code check warnings
cuichenx Sep 26, 2023
9c32b15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
2db5fef
revert changes to remove apex dependency (mixed apex+nemo microbatch …
cuichenx Sep 26, 2023
53ed5fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
838a7b7
add more deprecation notices
cuichenx Sep 26, 2023
cf3892f
update deprecation notices
cuichenx Sep 26, 2023
8e238fa
update deprecation notices
cuichenx Sep 26, 2023
dc0fe10
consolidate peft and sft scripts
cuichenx Sep 26, 2023
19b831f
update CI tests
cuichenx Sep 26, 2023
772f93b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
5817f15
notebook branch points to main to prepare for merge
cuichenx Sep 27, 2023
9353c6d
fix gpt and t5 validation with any metric other than loss
cuichenx Sep 27, 2023
aa3dba0
Merge branch 'main' into peft_refactor
cuichenx Sep 27, 2023
918f9ca
support pre-extracted checkpoints
cuichenx Sep 28, 2023
79a878e
Merge branch 'main' into peft_refactor
cuichenx Sep 28, 2023
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -52,6 +52,7 @@
'attr', # attrdict in requirements, attr in import
'torchmetrics', # inherited from PTL
'lightning_utilities', # inherited from PTL
'lightning_fabric',
'apex',
'megatron.core',
'transformer_engine',
14 changes: 12 additions & 2 deletions docs/source/nlp/api.rst
@@ -81,7 +81,6 @@ Modules
.. autoclass:: nemo.collections.nlp.modules.common.megatron.module.Float16Module
:show-inheritance:


.. autoclass:: nemo.collections.nlp.models.language_modeling.megatron.gpt_model.GPTModel
:show-inheritance:
:no-members:
@@ -140,11 +139,22 @@ Datasets
.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.ul2_dataset.UL2Dataset
:show-inheritance:


Adapter Mixin Class
-------------------------

.. autoclass:: nemo.collections.nlp.parts.mixins.nlp_adapter_mixins.NLPAdapterModelMixin
:show-inheritance:
:members: add_adapter, load_adapters, merge_cfg_with, merge_inference_cfg
:exclude-members: first_stage_of_pipeline, tie_weights, get_peft_state_dict, state_dict, sharded_state_dict, load_state_dict, on_load_checkpoint
:member-order: bysource


Exportable Model Classes
-------------------------

.. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTExportableModel
:show-inheritance:
:show-inheritance:

.. toctree::
:maxdepth: 1
1 change: 1 addition & 0 deletions docs/source/nlp/nemo_megatron/intro.rst
@@ -25,6 +25,7 @@ team at NVIDIA. NeMo Megatron supports several types of models:
prompt_learning
retro/retro_model
hiddens/hiddens_module
peft/landing_page


References
35 changes: 35 additions & 0 deletions docs/source/nlp/nemo_megatron/peft/landing_page.rst
@@ -0,0 +1,35 @@
Parameter-Efficient Fine-Tuning (PEFT)
======================================

PEFT is a popular technique used to efficiently finetune large language
models for use in various downstream tasks. When finetuning with PEFT,
the base model weights are frozen, and a few trainable adapter modules
are injected into the model, resulting in a very small number (<< 1%) of
trainable weights. With carefully chosen adapter modules and injection
points, PEFT achieves comparable performance to full finetuning at a
fraction of the computational and storage costs.

NeMo supports four PEFT methods which can be used with various
transformer-based models.

==================== ===== ===== ========= ==
\ GPT 3 NvGPT LLaMa 1/2 T5
==================== ===== ===== ========= ==
Adapters (Canonical) ✅ ✅ ✅ ✅
LoRA ✅ ✅ ✅ ✅
IA3 ✅ ✅ ✅ ✅
P-Tuning ✅ ✅ ✅ ✅
==================== ===== ===== ========= ==

Learn more about PEFT in NeMo with the :ref:`peftquickstart`, which provides an overview of how PEFT works
in NeMo. Read about the supported PEFT methods
`here <supported_methods.html>`__. For a practical example, take a look at
the `Step-by-step Guide <https://github.com/NVIDIA/NeMo/blob/main/tutorials/nlp/lora.ipynb>`__.

The API guide can be found `here <../../api.html#adapter-mixin-class>`__.

.. toctree::
:maxdepth: 1

quick_start
supported_methods
90 changes: 90 additions & 0 deletions docs/source/nlp/nemo_megatron/peft/quick_start.rst
@@ -0,0 +1,90 @@
.. _peftquickstart:


Quick Start Guide
=================

The quick start guide provides an overview of a PEFT workflow in NeMo.

Terminology: PEFT vs Adapter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This tutorial uses "PEFT" to describe the overall parameter-efficient
fine-tuning method, and "adapter" to describe the additional module
injected into a frozen base model. Each PEFT method can use one or more
types of adapters.

One of the PEFT methods is itself commonly referred to as "adapters",
because it was one of the first proposed uses of adapter modules for NLP.
This PEFT method is called "canonical" adapters here to distinguish the
two usages.

How PEFT works in NeMo models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Each PEFT method has one or more types of adapters that need to be
injected into the base model. In NeMo models, the adapter logic and
adapter weights are already built into the submodules, but they are
disabled by default for ordinary training and fine-tuning.

When doing PEFT, the adapter logic path is enabled by calling
``model.add_adapter(peft_cfg)``. In this function, the model scans each
of its submodules for adapter logic paths that apply to the current PEFT
method and enables them. The base model's weights are then frozen, while
the newly added adapter weights are left unfrozen and updated during
fine-tuning, which is what makes the method parameter-efficient.
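
A minimal sketch of what this looks like in practice, assuming a restored
``MegatronGPTSFTModel`` called ``model`` and that the PEFT config classes are
importable from ``nemo.collections.nlp.parts.peft_config`` (an assumed import
path, so adjust it to your installation):

.. code:: python

   from nemo.collections.nlp.parts.peft_config import LoraPEFTConfig

   # Enable the LoRA adapter logic paths and freeze the base model weights.
   model.add_adapter(LoraPEFTConfig(model.cfg))

   # Only the injected adapter weights should now require gradients.
   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
   total = sum(p.numel() for p in model.parameters())
   print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")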

PEFT config classes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Each PEFT method is specified by a ``PEFTConfig`` class which stores the
types of adapters applicable to the PEFT method, as well as
hyperparameters required to initialize these adapter modules. These four
PEFT methods are currently supported:

1. Adapters (canonical): ``CanonicalAdaptersPEFTConfig``
2. LoRA: ``LoraPEFTConfig``
3. IA3: ``IA3PEFTConfig``
4. P-Tuning: ``PtuningPEFTConfig``

These config classes make experimenting with different adapters as easy
as changing the config class.

Moreover, it is possible to combine PEFT methods in NeMo since they are
orthogonal to each other. This is done by passing a list of
``PEFTConfig`` objects to ``add_adapter`` instead of a single one. For
example, a common workflow is to combine P-Tuning with canonical
adapters, which can be achieved with
``model.add_adapter([PtuningPEFTConfig(model_cfg), CanonicalAdaptersPEFTConfig(model_cfg)])``,
as shown in the sketch below.
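
The snippet below sketches both patterns, assuming a restored SFT model
``model`` with config ``model_cfg`` and the same assumed import path as above
(the ``PEFT_CONFIGS`` mapping is a hypothetical helper, not part of the NeMo API):

.. code:: python

   from nemo.collections.nlp.parts.peft_config import (
       CanonicalAdaptersPEFTConfig,
       IA3PEFTConfig,
       LoraPEFTConfig,
       PtuningPEFTConfig,
   )

   # Hypothetical helper: switching PEFT methods is just a matter of
   # instantiating a different config class from the same model config.
   PEFT_CONFIGS = {
       "adapter": CanonicalAdaptersPEFTConfig,
       "lora": LoraPEFTConfig,
       "ia3": IA3PEFTConfig,
       "ptuning": PtuningPEFTConfig,
   }

   # Pick one method...
   peft_cfg = PEFT_CONFIGS["lora"](model_cfg)
   model.add_adapter(peft_cfg)

   # ...or combine orthogonal methods by passing a list of configs instead:
   # model.add_adapter([PtuningPEFTConfig(model_cfg), CanonicalAdaptersPEFTConfig(model_cfg)])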

Base model classes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
PEFT in NeMo is built with a mix-in class that does not belong to any
model in particular. This means that the same interface is available to
different NeMo models. Currently, NeMo supports PEFT for GPT-style
models such as GPT 3, NvGPT, LLaMa 1/2 (``MegatronGPTSFTModel``), as
well as T5 (``MegatronT5SFTModel``).

Full finetuning vs PEFT
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
You can switch between full fine-tuning and PEFT by removing calls to
``add_adapter`` and ``load_adapters``.

The code snippet below illustrates the core API of full fine-tuning and
PEFT.

.. code:: diff

trainer = MegatronTrainerBuilder(config).create_trainer()
model_cfg = MegatronGPTSFTModel.merge_cfg_with(config.model.restore_from_path, config)

model = MegatronGPTSFTModel.restore_from(restore_path, model_cfg, trainer) # restore from pretrained ckpt
+ peft_cfg = LoraPEFTConfig(model_cfg)
+ model.add_adapter(peft_cfg)
trainer.fit(model) # saves adapter weights only

# Restore from base then load adapter API
model = MegatronGPTSFTModel.restore_from(restore_path, model_cfg, trainer)
+ model.load_adapters(adapter_save_path, peft_cfg)
model.freeze()
trainer.predict(model)
71 changes: 71 additions & 0 deletions docs/source/nlp/nemo_megatron/peft/supported_methods.rst
@@ -0,0 +1,71 @@


Supported PEFT methods
----------------------

NeMo supports the following PEFT tuning methods:

1. **Adapters (Canonical)**: `Parameter-Efficient Transfer Learning for
NLP <http://arxiv.org/abs/1902.00751>`__

- Adapters (Houlsby setup) is one of the first PEFT methods applied
to NLP. Adapter tuning is more efficient than full fine-tuning
because the base model weights are frozen, while only a small
number of adapter module weights are updated. In this method, two
linear layers with a bottleneck and a non-linear activation are
inserted into each transformer layer via a residual connection. In
each case, the output linear layer is initialized to 0 to ensure
that an untrained adapter does not affect the normal forward pass
of the transformer layer.

2. **LoRA**: `LoRA: Low-Rank Adaptation of Large Language
Models <http://arxiv.org/abs/2106.09685>`__

- LoRA makes fine-tuning efficient by representing weight updates
with two low rank decomposition matrices. The original model
weights remain frozen, while the low rank decomposition matrices
are updated to adapt to the new data, so the number of trainable
parameters is kept low. In contrast with adapters, the original
model weights and adapted weights can be combined during
inference, avoiding any architectural change or additional latency
in the model at inference time.
- The matrix decomposition operation can be applied to any linear
layer, but in practice it is only applied to the K, Q, V
projection matrices (sometimes just to the Q and V projections).
Since NeMo's attention implementation fuses KQV into a single
projection, our LoRA implementation learns a single low-rank
projection for KQV in a combined fashion (see the sketch after this
list).

3. **IA3**: `Few-Shot Parameter-Efficient Fine-Tuning is Better and
Cheaper than In-Context Learning <http://arxiv.org/abs/2205.05638>`__

- IA3 makes fine-tuning efficient by rescaling activations with
learned vectors. The rescaling layers are injected in the
attention (for key and value) and feedforward modules in the base
model. Similar to other PEFT methods, only the rescaling vectors
are updated during fine-tuning to adapt to the new data so the
number of updated parameters is low. However, since rescaling
vectors are much smaller than low rank matrices (LoRA) and
bottleneck layers (Adapters), IA3 cuts down the number of
trainable parameters further by an order of magnitude. The
learned rescaling vectors can also be merged with the base
weights, leading to no architectural change and no additional
latency at inference time.

4. **P-Tuning**: `GPT Understands,
Too <https://arxiv.org/abs/2103.10385>`__

- P-tuning is an example of the prompt learning family of methods,
in which trainable virtual tokens are inserted into the model
input prompt to induce it to perform a task. Virtual tokens (also
called "continuous" or "soft" tokens) are embeddings that have no
concrete mapping to strings or characters within the model’s
vocabulary. They are simply 1D vectors that match the
dimensionality of the real token embeddings that make up the model's
vocabulary.
- In p-tuning, an intermediate LSTM or MLP model is used to generate
virtual token embeddings. We refer to this intermediate model as
our ``prompt_encoder``. The prompt encoder parameters are randomly
initialized at the start of p-tuning. All base model parameters
are frozen, and only the prompt encoder weights are updated at
each training step.
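
For reference, here is a minimal sketch of the LoRA update described above,
using the notation of the LoRA paper rather than NeMo's internal variable
names: for a frozen weight matrix :math:`W_0 \in \mathbb{R}^{d \times k}`,
LoRA learns :math:`B \in \mathbb{R}^{d \times r}` and
:math:`A \in \mathbb{R}^{r \times k}` with rank :math:`r \ll \min(d, k)` and
computes

.. math::

   h = W_0 x + \frac{\alpha}{r} B A x,

where :math:`\alpha` is a constant scaling factor and :math:`B` is initialized
to zero, so an untrained adapter leaves the base model's output unchanged. At
inference time, :math:`\frac{\alpha}{r} B A` can be folded into :math:`W_0`,
which is why LoRA adds no architectural change or extra latency.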
8 changes: 4 additions & 4 deletions examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py
@@ -18,9 +18,9 @@
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin

from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel
from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel
from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model
from nemo.collections.nlp.models.language_modeling.megatron_t5_sft_model import MegatronT5SFTModel
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, MegatronHalfPrecisionPlugin, NLPDDPStrategy
from nemo.core.config import hydra_runner
from nemo.utils import logging
@@ -122,13 +122,13 @@ def main(cfg) -> None:
model = load_from_checkpoint_dir(MegatronT0Model, cfg, trainer, modify_confg_fn=_modify_config)
else:
if cfg.model.restore_from_path:
t5_cfg = MegatronT5FinetuneModel.restore_from(
t5_cfg = MegatronT5SFTModel.restore_from(
restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
)
model = load_from_nemo(MegatronT5FinetuneModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
model = load_from_nemo(MegatronT5SFTModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
else:
validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
model = load_from_checkpoint_dir(MegatronT5FinetuneModel, cfg, trainer, modify_confg_fn=_modify_config)
model = load_from_checkpoint_dir(MegatronT5SFTModel, cfg, trainer, modify_confg_fn=_modify_config)

model.freeze()
trainer.validate(model)
@@ -21,9 +21,9 @@
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel
from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel
from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model
from nemo.collections.nlp.models.language_modeling.megatron_t5_sft_model import MegatronT5SFTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
@@ -206,13 +206,13 @@ def main(cfg) -> None:
model = load_from_checkpoint_dir(MegatronT0Model, cfg, trainer, modify_confg_fn=_modify_config)
else:
if cfg.model.restore_from_path:
t5_cfg = MegatronT5FinetuneModel.restore_from(
t5_cfg = MegatronT5SFTModel.restore_from(
restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
)
model = load_from_nemo(MegatronT5FinetuneModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
model = load_from_nemo(MegatronT5SFTModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
else:
validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
model = load_from_checkpoint_dir(MegatronT5FinetuneModel, cfg, trainer, modify_confg_fn=_modify_config)
model = load_from_checkpoint_dir(MegatronT5SFTModel, cfg, trainer, modify_confg_fn=_modify_config)

trainer.fit(model)
trainer.validate(model)