feat: add IA3 prompt tuning #2

Draft: wants to merge 9 commits into base: main
Changes from 1 commit
4 changes: 3 additions & 1 deletion configs/local_setup.yml
@@ -26,5 +26,7 @@
"log-dir": "logs",
"use_wandb": True,
"wandb_host": "https://api.wandb.ai",
"wandb_project": "neox"
"wandb_project": "neox",
"num_gpus": 1,
"ia3_prompt_tuning": True
}
1 change: 1 addition & 0 deletions megatron/checkpointing.py
@@ -241,6 +241,7 @@ def load_checkpoint(
load_optimizer_states=load_optim_and_scheduler,
load_lr_scheduler_states=load_optim_and_scheduler,
tag=tag,
load_module_strict=False
)

if checkpoint_name is None:
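The new `load_module_strict=False` argument is what allows a base-model checkpoint, which contains no IA3 vectors, to be loaded into a model that now defines `l_k`, `l_v`, and `l_ff`. A minimal sketch of the idea in plain PyTorch (not the DeepSpeed loader used above; `TinyBlock` is a hypothetical module, not part of this PR):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, d=8, ia3=False):
        super().__init__()
        self.proj = nn.Linear(d, d)
        if ia3:
            # New IA3 rescaling vector, initialized to ones (a no-op until trained).
            self.l_ff = nn.Parameter(torch.ones(d))

    def forward(self, x):
        out = self.proj(x)
        return out * self.l_ff if hasattr(self, "l_ff") else out

pretrained = TinyBlock(ia3=False)   # stands in for the base checkpoint
finetune = TinyBlock(ia3=True)      # same weights plus the IA3 vector

state = pretrained.state_dict()
# strict=True would fail with "Missing key(s): l_ff"; strict=False loads what matches.
missing, unexpected = finetune.load_state_dict(state, strict=False)
print(missing)     # ['l_ff'] -> stays at its ones initialization
print(unexpected)  # []
```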
168 changes: 166 additions & 2 deletions megatron/model/transformer.py
@@ -18,6 +18,7 @@
"""Transformer."""

import math
import sys
import torch
import torch.nn.functional as F
import torch.nn as nn
@@ -93,7 +94,9 @@ def __init__(
if self.activation_type == "geglu"
else ff_mult * neox_args.hidden_size
)
self.dense_h_to_4h = mpu.ColumnParallelLinear(
mlp_column_parallel_cls = getattr(mpu, neox_args.mlp_column_parallel_cls)

self.dense_h_to_4h = mlp_column_parallel_cls(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim,
@@ -590,6 +593,166 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
return output, bias


class ParallelSelfAttentionIA3(ParallelSelfAttention):
def __init__(
self,
neox_args,
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
rpe=None,
rotary=False,
use_cache=False,
parallel_output=False,
):
super().__init__(
neox_args,
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
rpe=rpe,
rotary=rotary,
use_cache=use_cache,
parallel_output=parallel_output,
)
self.l_k = self._create_ia3_parameter(neox_args)
self.l_v = self._create_ia3_parameter(neox_args)

def _create_ia3_parameter(self, neox_args):
if neox_args.use_cpu_initialization:
param = torch.nn.Parameter(
torch.empty(
self.hidden_size_per_partition, dtype=neox_args.params_dtype
)
)
else:
param = torch.nn.Parameter(
torch.empty(
self.hidden_size_per_partition,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
param.model_parallel = True
param.partition_dim = 0
#param.stride = stride
# Always initialize to ones.
with torch.no_grad():
torch.nn.init.ones_(param)
return param

def forward(self, hidden_states, attention_mask, layer_past=None):

# hidden_states: [sq, b, h]

# =====================
# Query, Key, and Value
# =====================

# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(
mixed_x_layer, 3
)
# Apply IA3 rescaling to keys & values:
def _apply_ia3_rescaling(layer, scale_vector):
layer_size = layer.shape
layer = layer.reshape(layer_size[0], layer_size[1], -1)
layer *= scale_vector
return layer.reshape(layer_size)

key_layer = _apply_ia3_rescaling(key_layer, self.l_k)
value_layer = _apply_ia3_rescaling(value_layer, self.l_v)

if exists(self.rotary_emb):
if exists(self.rotary_ndims):
# partial rotary
query_rot, query_pass = (
query_layer[..., : self.rotary_ndims],
query_layer[..., self.rotary_ndims :],
)
key_rot, key_pass = (
key_layer[..., : self.rotary_ndims],
key_layer[..., self.rotary_ndims :],
)
else:
# full rotary
query_rot, key_rot = query_layer, key_layer
apply_rotary_fn = (
apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
)

seq_len = key_layer.shape[0]
offset = 0
if exists(layer_past) and layer_past.numel() > 0:
offset = layer_past[0].shape[0]
seq_len += offset
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_rotary_fn(
query_rot, key_rot, cos, sin, offset=offset
)

if exists(self.rotary_ndims):
query_layer = torch.cat((query_layer, query_pass), dim=-1)
key_layer = torch.cat((key_layer, key_pass), dim=-1)

# ==================================
# Cache key and value for inference
# ==================================

if exists(layer_past) and layer_past.numel() > 0:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
value_layer = torch.cat(
(past_value.type_as(value_layer), value_layer), dim=0
)

if self.use_cache:
present = torch.stack((key_layer, value_layer))

if self.use_flash_attention:
context_layer = self.flash_attention(query_layer, key_layer, value_layer)
elif not self.sparse:
context_layer = self.attention(
query_layer, key_layer, value_layer, layer_past, attention_mask
)
else:
context_layer = self.sparse_attention(
query_layer, key_layer, value_layer, attention_mask
)

# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_partition,
)
context_layer = context_layer.view(*new_context_layer_shape)

# =================
# Output. [sq, b, h]
# =================

output, bias = self.dense(context_layer)

if self.use_cache:
output = [output, present]

return output, bias


class ParallelTransformerLayer(nn.Module):
"""A single transformer layer.

@@ -625,9 +788,10 @@ def __init__(

if self.gpt_j_residual:
self.reduce = mpu.mappings.reduce_from_model_parallel_region
self_attention_cls = getattr(sys.modules[__name__], neox_args.self_attention_cls)

# Self attention.
self.attention = ParallelSelfAttention(
self.attention = self_attention_cls(
neox_args=neox_args,
attention_mask_func=attention_mask_func,
init_method=init_method,
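For context, the attention-side change in `ParallelSelfAttentionIA3` reduces to two learned vectors that rescale keys and values element-wise and are initialized to ones, so the network starts out identical to the frozen base model. A single-device sketch of that rescaling under assumed shapes (no model parallelism; not the exact class above):

```python
import torch
import torch.nn as nn

def ia3_rescale(layer: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Element-wise rescale of [sq, b, np, hn] activations by a [np * hn] vector."""
    sq, b, heads, head_dim = layer.shape
    flat = layer.reshape(sq, b, heads * head_dim)
    return (flat * scale).reshape(sq, b, heads, head_dim)

sq, b, heads, head_dim = 5, 2, 4, 16                # hypothetical sizes
l_k = nn.Parameter(torch.ones(heads * head_dim))    # ones init -> no-op before training
l_v = nn.Parameter(torch.ones(heads * head_dim))

key_layer = torch.randn(sq, b, heads, head_dim)
value_layer = torch.randn(sq, b, heads, head_dim)

rescaled_k = ia3_rescale(key_layer, l_k)
rescaled_v = ia3_rescale(value_layer, l_v)

# Before any training step the rescaling changes nothing:
assert torch.allclose(rescaled_k, key_layer)
assert torch.allclose(rescaled_v, value_layer)
```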
1 change: 1 addition & 0 deletions megatron/mpu/__init__.py
@@ -36,6 +36,7 @@
from .initialize import model_parallel_is_initialized

from .layers import ColumnParallelLinear
from .layers import ColumnParallelLinearIA3
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import ParallelRelativePositionBias
65 changes: 65 additions & 0 deletions megatron/mpu/layers.py
@@ -564,6 +564,71 @@ def forward(self, input_):
return output, output_bias


class ColumnParallelLinearIA3(ColumnParallelLinear):
def __init__(
self,
neox_args,
input_size,
output_size,
bias=True,
gather_output=True,
init_method=init.xavier_normal_,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
mup_rescale_parameters=False,
):
super().__init__(
neox_args,
input_size,
output_size,
bias=bias,
gather_output=gather_output,
init_method=init_method,
stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test,
skip_bias_add=skip_bias_add,
mup_rescale_parameters=mup_rescale_parameters
)
if neox_args.use_cpu_initialization:
self.l_ff = Parameter(
torch.empty(
self.output_size_per_partition, dtype=neox_args.params_dtype
)
)
else:
self.l_ff = Parameter(
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
self.l_ff.model_parallel = True
self.l_ff.partition_dim = 0
self.l_ff.stride = stride
# Always initialize l_ff to ones.
with torch.no_grad():
torch.nn.init.ones_(self.l_ff)

def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
input_ /= self.width_mult()
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.

bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel *= self.l_ff # apply IA3 rescaling
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias

class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.

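The MLP side is analogous: `ColumnParallelLinearIA3` multiplies the output of the first MLP projection by a learned vector `l_ff` that starts at ones. A single-GPU sketch without the model-parallel plumbing (`LinearIA3` is a hypothetical stand-in, not this PR's class):

```python
import torch
import torch.nn as nn

class LinearIA3(nn.Module):
    """nn.Linear followed by an element-wise rescale of its output, as in (IA)^3."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.l_ff = nn.Parameter(torch.ones(out_features))  # ones init -> no-op at start

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * self.l_ff

layer = LinearIA3(in_features=8, out_features=32)
x = torch.randn(4, 8)
out = layer(x)
print(out.shape)  # torch.Size([4, 32])
```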
16 changes: 16 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -355,11 +355,27 @@ class NeoXArgsModel(NeoXArgsTemplate):
"""

output_layer_parallelism: Literal["row", "column"] = "row"
"""
Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
"""

ia3_prompt_tuning: bool = False
"""
Run IA3 prompt tuning, based on:
Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning
https://arxiv.org/pdf/2205.05638.pdf
"""

self_attention_cls: str = "ParallelSelfAttention"
"""
Name of the self-attention class to use; set to "ParallelSelfAttentionIA3" automatically when ia3_prompt_tuning is enabled.
"""

mlp_column_parallel_cls: str = "ColumnParallelLinear"
"""
Name of the column-parallel linear class used for the MLP's dense_h_to_4h projection; set to "ColumnParallelLinearIA3" automatically when ia3_prompt_tuning is enabled.
"""


@dataclass
class NeoXArgsOptimizer(NeoXArgsTemplate):
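Both new arguments are plain strings that get resolved to classes at model-build time via `getattr`, which is how `transformer.py` swaps in the IA3 variants without changing any call sites. A small self-contained illustration of that lookup-by-name pattern (the classes below are stand-ins, not the real NeoX ones):

```python
import sys

class ParallelSelfAttention:                 # stand-in for the base NeoX class
    pass

class ParallelSelfAttentionIA3(ParallelSelfAttention):   # stand-in for the IA3 variant
    pass

def build_attention(self_attention_cls: str = "ParallelSelfAttention"):
    # Same pattern as transformer.py: look the class up by name in this module.
    cls = getattr(sys.modules[__name__], self_attention_cls)
    return cls()

print(type(build_attention()).__name__)                            # ParallelSelfAttention
print(type(build_attention("ParallelSelfAttentionIA3")).__name__)  # ParallelSelfAttentionIA3
```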
14 changes: 14 additions & 0 deletions megatron/training.py
@@ -385,6 +385,10 @@ def get_model(neox_args, use_cache=False):
# If mup isn't being used anyways, this has no effect.
old_use_mup = neox_args.use_mup
neox_args.use_mup = False
if neox_args.ia3_prompt_tuning:
neox_args.mlp_column_parallel_cls = "ColumnParallelLinearIA3"
neox_args.self_attention_cls = "ParallelSelfAttentionIA3"

model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
@@ -412,6 +416,16 @@
for name, param in model.named_parameters():
if not "soft_embedding" in name:
param.requires_grad = False
elif neox_args.ia3_prompt_tuning:
layers_to_train = ["l_ff", "l_k", "l_v"]
for name, param in model.named_parameters():
if not any([x in name for x in layers_to_train]):
param.requires_grad = False

trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
print(f"Number of trainable parameters: {trainable_params}")

if not neox_args.is_pipe_parallel:
# Export PipeParallel model to nn.Sequential model to avoid the overhead of deepspeed's pipe parallel training
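As a sanity check on the trainable-parameter count printed by `get_model`: each transformer layer contributes only `l_k` and `l_v` (hidden size each, summed across model-parallel partitions) plus `l_ff` (the MLP's ff_dim). A back-of-the-envelope example under assumed model dimensions (hypothetical 32-layer, 4096-hidden model with ff_mult 4 and no GEGLU):

```python
# Rough count of trainable (IA)^3 parameters for a hypothetical configuration;
# the real number depends on the model's hidden size, layer count, and MLP width.
hidden_size, num_layers, ff_mult = 4096, 32, 4
per_layer = 2 * hidden_size + ff_mult * hidden_size   # l_k + l_v + l_ff
total = num_layers * per_layer
print(total)  # 786432 -- roughly 0.01% of a ~7B-parameter base model
```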