Add BloomModel hydra support (#129)
jon-tow authored Dec 9, 2022
1 parent 33deeb1 commit 1a3461d
Showing 1 changed file with 203 additions and 5 deletions.
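As context for the diff below, here is a minimal sketch (not part of this commit) of how the frozen BLOOM branch could be wired up from a Hugging Face checkpoint. The checkpoint name and the `num_unfrozen` split are illustrative assumptions; the `transformer.h`, `transformer.ln_f`, and `lm_head` attribute paths follow the `BloomForCausalLM` module layout.

```python
import transformers

from trlx.model.nn.ppo_models import BloomModelBranch

# Illustrative wiring of the frozen reference branch from the top of a BLOOM
# model. The checkpoint name and `num_unfrozen` are assumptions made for this
# example, not values taken from the commit.
base = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
num_unfrozen = 2  # number of top blocks the trainable policy keeps unfrozen

branch = BloomModelBranch(
    config=base.config,
    transformer_blocks=base.transformer.h[-num_unfrozen:],  # upper decoder blocks
    final_norm=base.transformer.ln_f,                        # final LayerNorm
    lm_head=base.lm_head,                                    # output projection
)
```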
208 changes: 203 additions & 5 deletions trlx/model/nn/ppo_models.py
@@ -9,7 +9,8 @@
import transformers
from torchtyping import TensorType
from transformers.modeling_outputs import ModelOutput
-from transformers.models.opt.modeling_opt import _make_causal_mask, _expand_mask
+from transformers.models.bloom import modeling_bloom
+from transformers.models.opt import modeling_opt

from trlx.data.method_configs import MethodConfig, register_method
from trlx.utils.modeling import (
@@ -698,15 +699,15 @@ def forward(
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
+            combined_attention_mask = modeling_opt._make_causal_mask(
input_shape,
hidden_states.dtype,
past_key_values_length=past_key_values_length,
).to(hidden_states.device)

if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(
+            expanded_attn_mask = modeling_opt._expand_mask(
attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
).to(hidden_states.device)
combined_attention_mask = (
@@ -798,6 +799,193 @@ def forward(
)


class BloomModelBranch(transformers.PreTrainedModel):
"""
    BloomModelBranch implements the frozen upper trunk of the reference model
    used when computing the PPO KL-divergence penalty. It deep-copies the upper
    transformer blocks, the final norm, and the lm_head passed in from the base
    model so that the reference trunk can be kept frozen.
"""

def __init__(
self,
config: transformers.PretrainedConfig,
transformer_blocks: nn.ModuleList,
final_norm: nn.Module,
lm_head: nn.Module,
):
super().__init__(config)

# Defined by the main trunk
self.hidden_size = hf_get_hidden_size(config)
self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks))
self.final_norm = deepcopy(final_norm)
self.lm_head = deepcopy(lm_head)

# Model parallel
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False

# Turning off grad saves memory
for block in self.transformer_blocks:
for parameter in block.parameters():
parameter.requires_grad = False
        for parameter in self.lm_head.parameters():
parameter.requires_grad = False

def forward(
self,
hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids
output_shape: torch.Tensor, # output_size given by main trunk
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = False,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

#######################################################################
        # Modified BloomModel.forward
#######################################################################

batch_size, seq_length = hidden_states.shape[:2]

if past_key_values is None:
past_key_values = tuple([None] * len(self.transformer_blocks))

# Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))

presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None

# Compute alibi tensor: check modeling_bloom.build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)

alibi = modeling_bloom.build_alibi_tensor(
attention_mask, self.config.n_head, dtype=hidden_states.dtype
)

# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
input_shape = (batch_size, seq_length)
_, src_length = input_shape

if src_length > 1:
combined_attention_mask = modeling_bloom._make_causal_mask(
input_shape,
device=device,
past_key_values_length=past_key_values_length,
)

# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
expanded_attn_mask = modeling_bloom._expand_mask(
attention_mask, tgt_length=src_length
)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
causal_mask = combined_attention_mask

for i, (block, layer_past) in enumerate(
zip(self.transformer_blocks, past_key_values)
):

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)

hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)

if output_attentions:
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)

# Add last hidden state
hidden_states = self.final_norm(hidden_states)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

#######################################################################
# End of modified BloomModel.forward
#######################################################################

lm_logits = self.lm_head(hidden_states)

if not return_dict:
return tuple(
v
for v in [
lm_logits,
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)

return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=None,
value=None,
)
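For a concrete picture of the masking logic in the forward pass above, the following standalone sketch mirrors the ALiBi and boolean causal-mask construction. Shapes are illustrative, and it assumes a transformers release from around this commit (circa v4.25) in which `modeling_bloom` exposes `build_alibi_tensor`, `_make_causal_mask`, and `_expand_mask` with the signatures used here.

```python
import torch
from transformers.models.bloom import modeling_bloom

# Illustrative shapes; not taken from the diff.
batch_size, seq_length, num_heads = 2, 5, 16
attention_mask = torch.ones(batch_size, seq_length)

# ALiBi position biases, one slope per attention head.
alibi = modeling_bloom.build_alibi_tensor(
    attention_mask, num_heads, dtype=torch.float32
)  # -> (batch_size * num_heads, 1, seq_length)

# Boolean masks where True marks positions that must not be attended to.
causal_mask = modeling_bloom._make_causal_mask(
    (batch_size, seq_length), device=attention_mask.device, past_key_values_length=0
)
padding_mask = modeling_bloom._expand_mask(attention_mask, tgt_length=seq_length)

# The branch's forward pass combines the two with a logical OR.
combined = padding_mask | causal_mask  # -> (batch_size, 1, seq_length, seq_length)
```

Because these masks are boolean, the forward pass above merges them with `|` rather than by adding large negative values to the attention scores.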


def hf_get_causal_lm_branch_class(
config: transformers.PretrainedConfig,
) -> "ModelBranch":
@@ -809,14 +997,24 @@ def hf_get_causal_lm_branch_class(
"GPTNeoXForCausalLM",
]
opt_branch_supported_archs = ["OPTForCausalLM"]
bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"]
arch = config.architectures[0]
if arch in gpt_branch_supported_archs:
return GPTModelBranch
elif arch in opt_branch_supported_archs:
return OPTModelBranch
elif arch in bloom_branch_supported_archs:
return BloomModelBranch
else:
all_supported_archs = sum(
[
gpt_branch_supported_archs,
opt_branch_supported_archs,
bloom_branch_supported_archs,
],
[],
)
raise ValueError(
f"Unsupported architecture: `{arch}`. The following architectures are "
"available for model branching:\n"
f"{gpt_branch_supported_archs + opt_branch_supported_archs}"
"available for model branching:\n{all_supported_archs}"
)
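As a quick usage sketch (not part of the commit), the updated dispatcher resolves a BLOOM config to the new branch class; the checkpoint name below is only an example.

```python
import transformers

from trlx.model.nn.ppo_models import (
    BloomModelBranch,
    hf_get_causal_lm_branch_class,
)

# Hypothetical check that a BLOOM checkpoint is routed to the new branch class.
config = transformers.AutoConfig.from_pretrained("bigscience/bloom-560m")
print(config.architectures)  # expected: ['BloomForCausalLM']

branch_cls = hf_get_causal_lm_branch_class(config)
assert branch_cls is BloomModelBranch
```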
