Add T5 to delta modifier map #234

Merged (12 commits into CarperAI:main, Feb 6, 2023)
Conversation

aaronrmm
Contributor

This adds an entry to utils/modeling.py::MODIFIED_MODULES_DICT for T5, as per #173 (comment).

This is a guess, using the other entries as examples and by visualizing a T5 model with bigmodelvis.Visualization:

from bigmodelvis import Visualization
from transformers import AutoModelForSeq2SeqLM
backbone_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")
model_vis = Visualization(backbone_model)
graph = model_vis.structure_graph()
print(graph)

I can run it, though I have not been able to complete a full training job without running out of memory.

print(graph) output:

root
├── shared(Embedding),lm_head(Linear) weight:[32128, 2048]
├── encoder (T5Stack)
│   ├── embed_tokens (Embedding) weight:[32128, 2048]
│   ├── block (ModuleList)
│   │   ├── 0 (T5Block)
│   │   │   └── layer (ModuleList)
│   │   │       ├── 0 (T5LayerSelfAttention)
│   │   │       │   ├── SelfAttention (T5Attention)
│   │   │       │   │   ├── q,k,v,o(Linear) weight:[2048, 2048]
│   │   │       │   │   └── relative_attention_bias (Embedding) weight:[32, 32]
│   │   │       │   └── layer_norm (T5LayerNorm) weight:[2048]
│   │   │       └── 1 (T5LayerFF)
│   │   │           ├── DenseReluDense (T5DenseGatedActDense)
│   │   │           │   ├── wi_0,wi_1(Linear) weight:[5120, 2048]
│   │   │           │   └── wo (Linear) weight:[2048, 5120]
│   │   │           └── layer_norm (T5LayerNorm) weight:[2048]
│   │   └── 1-23(T5Block)
│   │       └── layer (ModuleList)
│   │           ├── 0 (T5LayerSelfAttention)
│   │           │   ├── SelfAttention (T5Attention)
│   │           │   │   └── q,k,v,o(Linear) weight:[2048, 2048]
│   │           │   └── layer_norm (T5LayerNorm) weight:[2048]
│   │           └── 1 (T5LayerFF)
│   │               ├── DenseReluDense (T5DenseGatedActDense)
│   │               │   ├── wi_0,wi_1(Linear) weight:[5120, 2048]
│   │               │   └── wo (Linear) weight:[2048, 5120]
│   │               └── layer_norm (T5LayerNorm) weight:[2048]
│   └── final_layer_norm (T5LayerNorm) weight:[2048]
└── decoder (T5Stack)
    ├── embed_tokens (Embedding) weight:[32128, 2048]
    ├── block (ModuleList)
    │   ├── 0 (T5Block)
    │   │   └── layer (ModuleList)
    │   │       ├── 0 (T5LayerSelfAttention)
    │   │       │   ├── SelfAttention (T5Attention)
    │   │       │   │   ├── q,k,v,o(Linear) weight:[2048, 2048]
    │   │       │   │   └── relative_attention_bias (Embedding) weight:[32, 32]
    │   │       │   └── layer_norm (T5LayerNorm) weight:[2048]
    │   │       ├── 1 (T5LayerCrossAttention)
    │   │       │   ├── EncDecAttention (T5Attention)
    │   │       │   │   └── q,k,v,o(Linear) weight:[2048, 2048]
    │   │       │   └── layer_norm (T5LayerNorm) weight:[2048]
    │   │       └── 2 (T5LayerFF)
    │   │           ├── DenseReluDense (T5DenseGatedActDense)
    │   │           │   ├── wi_0,wi_1(Linear) weight:[5120, 2048]
    │   │           │   └── wo (Linear) weight:[2048, 5120]
    │   │           └── layer_norm (T5LayerNorm) weight:[2048]
    │   └── 1-23(T5Block)
    │       └── layer (ModuleList)
    │           ├── 0 (T5LayerSelfAttention)
    │           │   ├── SelfAttention (T5Attention)
    │           │   │   └── q,k,v,o(Linear) weight:[2048, 2048]
    │           │   └── layer_norm (T5LayerNorm) weight:[2048]
    │           ├── 1 (T5LayerCrossAttention)
    │           │   ├── EncDecAttention (T5Attention)
    │           │   │   └── q,k,v,o(Linear) weight:[2048, 2048]
    │           │   └── layer_norm (T5LayerNorm) weight:[2048]
    │           └── 2 (T5LayerFF)
    │               ├── DenseReluDense (T5DenseGatedActDense)
    │               │   ├── wi_0,wi_1(Linear) weight:[5120, 2048]
    │               │   └── wo (Linear) weight:[2048, 5120]
    │               └── layer_norm (T5LayerNorm) weight:[2048]
    └── final_layer_norm (T5LayerNorm) weight:[2048]

fix: remove attention_with_bias delta set for T5
Collaborator

@jon-tow jon-tow left a comment

Thanks for looking into this @aaronrmm !

Right now, no LoRA adapters are injected into the model:

[screenshot]

From a quick skim, it looks as though we'll need to modify the existing regex pattern to capture T5 layers, since they're not named in the same format as the rest of the causal models we currently have. See the existing regex pattern and modifier handler below:

trlx/trlx/utils/modeling.py, lines 334 to 360 in b2ce1a4:

def generate_layer_regex(
    config: transformers.PretrainedConfig, num_layers_unfrozen: int = -1
) -> str:
    """Generates a regex range for the specified number of learnable layers."""
    if num_layers_unfrozen == -1:
        return "[r](\d)+."
    num_hidden_layers = hf_get_num_hidden_layers(config)
    start_layer = num_hidden_layers - num_layers_unfrozen
    if start_layer < 0:
        raise Exception(
            "Number of layers unfrozen cannot be greater than number of layers in the model"
        )
    pattern = f"(?:{regex_for_range(start_layer, num_hidden_layers - 1)})."
    return f"[r]{pattern}"


def get_delta_modified_modules(
    config: transformers.PretrainedConfig,
    modified_modules: List[str],
    num_layers_unfrozen: int = -1,
) -> List[str]:
    """Returns a list of module names to be modified for a given delta method with
    the specified number of learnable layers."""
    prefix = generate_layer_regex(config, num_layers_unfrozen)
    module_list = [prefix + module for module in modified_modules]
    return module_list
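
For a concrete picture of what this produces, here is a minimal stand-alone sketch of how the layer-range prefix combines with T5-style module names; the "2[2-3]" range regex is only a stand-in, not necessarily what regex_for_range actually emits:

# Illustrative only: how a generate_layer_regex-style prefix combines with
# T5 module names. "2[2-3]" stands in for layers 22-23 of a 24-block stack.
prefix = "[r](?:2[2-3])."
modified_modules = ["layer.0.SelfAttention.q", "layer.0.SelfAttention.v"]
print([prefix + module for module in modified_modules])
# ['[r](?:2[2-3]).layer.0.SelfAttention.q', '[r](?:2[2-3]).layer.0.SelfAttention.v']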

Note: only decoder blocks are left (partially) unfrozen, because these are the only learnable layers of our seq2seq models; see

def freeze_bottom_seq2seq_layers(model: nn.Module, num_layers_unfrozen: int = 0):
    """Freezes the bottom transformer block layers of the specified model."""
    if num_layers_unfrozen == -1:
        return
    shared_embed = model.shared
    decoder_embed = model.decoder.embed_tokens
    encoder_blocks = model.encoder.block
    encoder_norm_layer = model.encoder.final_layer_norm
    decoder_norm_layer = model.decoder.final_layer_norm
    decoder_blocks = model.decoder.block[:-num_layers_unfrozen]
    blocks_to_freeze = (
        list(encoder_blocks)
        + list(decoder_blocks)
        + [shared_embed]
        + [encoder_norm_layer]
        + [decoder_norm_layer]
        + [decoder_embed]
    )
    for block in blocks_to_freeze:
        block.requires_grad_(False)
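
As a quick sanity check (assuming a T5 model such as the backbone_model loaded in the PR description), freezing all but the top two decoder blocks should leave only those top decoder blocks trainable:

# Sanity check; backbone_model is the flan-t5-xl instance loaded above.
freeze_bottom_seq2seq_layers(backbone_model, num_layers_unfrozen=2)

# All encoder blocks and all but the last two decoder blocks are frozen;
# the top decoder blocks keep requires_grad=True.
assert not any(p.requires_grad for p in backbone_model.encoder.block.parameters())
assert not any(p.requires_grad for p in backbone_model.decoder.block[0].parameters())
assert all(p.requires_grad for p in backbone_model.decoder.block[-1].parameters())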

@@ -97,6 +97,8 @@ def test_hf_attr_getters(model_name: str):
         "EleutherAI/gpt-neox-20b",
         "facebook/opt-1.3b",
         "bigscience/bloom-560m",
+        "google/flan-t5-large",
+        "google/flan-t5-xxl"
Collaborator

Remove the "google/flan-t5-xxl" test case since it's the same architecture as "google/flan-t5-large"

Contributor Author

@jon-tow
@ethankim00 helped me with the layer naming; we changed it to:

        "all": [
            "layer.0.SelfAttention.q",
            "layer.0.SelfAttention.k",
            "layer.0.SelfAttention.v",
            "layer.0.SelfAttention.o",
            "layer.1.EncDecAttention.q",
            "layer.1.EncDecAttention.k",
            "layer.1.EncDecAttention.v",
            "layer.1.EncDecAttention.o",
            "layer.1.DenseReluDense.wo",
            "layer.1.DenseReluDense.wi_0",
            "layer.1.DenseReluDense.wi_1",
            "layer.2.DenseReluDense.wo",
            "layer.2.DenseReluDense.wi_0",
            "layer.2.DenseReluDense.wi_1",
        ],

This is working, in that I'm now getting nonzero parameter counts for those OpenDelta ratios.

Also, are you saying we should exclude encoder layers from this modified-modules list? That would require some changes to the regex pattern.
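
As an aside, here is a quick illustrative check (assuming the backbone_model from the PR description) that these suffixes actually resolve to modules in the graph above:

# Count how many modules in the T5 graph end with each suffix.
suffixes = [
    "layer.0.SelfAttention.q",
    "layer.1.EncDecAttention.q",
    "layer.2.DenseReluDense.wi_0",
]
for suffix in suffixes:
    hits = [name for name, _ in backbone_model.named_modules() if name.endswith(suffix)]
    print(f"{suffix}: {len(hits)} matches")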

Collaborator

For the causal (decoder-only) models, only the upper trainable layers, determined by num_layers_unfrozen, are tagged with LoRA layers. To be consistent, for seq2seq models we should also inject LoRA layers only into the trainable parts: the upper blocks of the decoder (see freeze_bottom_seq2seq_layers() above). This map, however, seems to also modify layers in the encoder block. I think you may be able to just add a condition that passes base_model.decoder for seq2seq architectures and base_model for causal ones in the line below.

delta_model = delta_model_class(model.base_model, **delta_kwargs)
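
A rough sketch of that suggestion (illustrative only; is_seq2seq is a placeholder, and this is not necessarily the code the PR ended up with):

# Hypothetical sketch of the suggested branch; is_seq2seq is a placeholder.
if is_seq2seq:
    # Seq2seq: only decoder blocks are ever unfrozen, so inject deltas there.
    delta_model = delta_model_class(model.base_model.decoder, **delta_kwargs)
else:
    # Causal LM: inject into the full base model, as before.
    delta_model = delta_model_class(model.base_model, **delta_kwargs)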

Contributor Author

Hmm. I'm caught up on the context and have pushed a potential solution, but your suggestion might be better; I just don't fully understand it. Let me know if you want to discuss. @jon-tow

Collaborator

@aaronrmm This is perfect! I wasn't sure if we'd need to modify the call in the base trainer but you proved it unnecessary :)

@aaronrmm aaronrmm requested a review from jon-tow February 4, 2023 23:44
Collaborator

@jon-tow jon-tow left a comment

Thanks, @aaronrmm!

@jon-tow jon-tow merged commit 6892fc3 into CarperAI:main Feb 6, 2023