Merged

27 commits
e62345e - bump versions (justheuristic, Dec 12, 2022)
88348cf - bump versions (justheuristic, Dec 12, 2022)
7018954 - yeet models (justheuristic, Dec 12, 2022)
f46d421 - y u no instal? (justheuristic, Dec 12, 2022)
1e2ab6b - fix imports (justheuristic, Dec 12, 2022)
4b35dd7 - fix edge case where session crashes when receiving seq length 0 (justheuristic, Dec 12, 2022)
67e8070 - Merge branch 'main' into bump (borzunov, Dec 12, 2022)
062cd51 - review (borzunov, Dec 12, 2022)
d227021 - mixin (justheuristic, Dec 12, 2022)
ab813ba - remix (justheuristic, Dec 12, 2022)
9bf813b - fix throughput (justheuristic, Dec 12, 2022)
524468e - fix throughput (justheuristic, Dec 12, 2022)
c473012 - benchmark throughput in CI jobs (justheuristic, Dec 12, 2022)
f9e0910 - reduce ban timeout (borzunov, Dec 12, 2022)
044e915 - fork pytest (justheuristic, Dec 12, 2022)
b12ad06 - review (mryab, Dec 13, 2022)
35e2c0a - review (mryab, Dec 13, 2022)
27ac588 - Update tests/test_aux_functions.py (justheuristic, Dec 13, 2022)
26fe612 - Merge branch 'bump' of github.com:bigscience-workshop/petals into bump (justheuristic, Dec 13, 2022)
b090dd2 - isort (mryab, Dec 13, 2022)
7f7f5dc - Update src/petals/server/handler.py (justheuristic, Dec 13, 2022)
7103268 - Update src/petals/bloom/modeling_utils.py (justheuristic, Dec 13, 2022)
fa632a3 - cleanup (mryab, Dec 13, 2022)
d1d59e4 - Merge branch 'bump' of github.com:bigscience-workshop/petals into bump (justheuristic, Dec 13, 2022)
110a307 - cleanup (justheuristic, Dec 13, 2022)
a24659f - Update src/petals/server/backend.py (justheuristic, Dec 13, 2022)
a61c5bb - check transformers version (justheuristic, Dec 13, 2022)
6 changes: 3 additions & 3 deletions setup.cfg
@@ -33,9 +33,9 @@ python_requires = >=3.7
install_requires =
torch>=1.12
bitsandbytes==0.34.0
accelerate==0.10.0
huggingface-hub==0.7.0
transformers==4.21.3
accelerate==0.15.0
huggingface-hub==0.11.1
transformers==4.25.1
protobuf>=3.20.3,<4.0dev
hivemind==1.1.3
Collaborator Author commented on hivemind==1.1.3: also gonna bump it, but it's a separate PR
humanfriendly
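These pins move together: transformers 4.25.1 is the version whose BloomBlock the new WrappedBloomBlock (further down in this diff) wraps, and accelerate/huggingface-hub are bumped alongside it. A minimal sketch for checking a local environment against the new pins, assuming each package exposes a standard `__version__` attribute; the PETALS_IGNORE_DEPENDENCY_VERSION escape hatch is the one added in block.py below.

```python
# Minimal sketch (not part of this PR): sanity-check an environment against the new pins.
# Assumes each package exposes a standard __version__ attribute.
import os

import accelerate
import huggingface_hub
import transformers

expected = {accelerate: "0.15.0", huggingface_hub: "0.11.1", transformers: "4.25.1"}
for module, pin in expected.items():
    print(f"{module.__name__}: installed {module.__version__}, pinned {pin}")

# block.py in this PR enforces the transformers pin at import time
# unless PETALS_IGNORE_DEPENDENCY_VERSION is set in the environment:
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1"
```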
2 changes: 0 additions & 2 deletions src/petals/bloom/__init__.py
@@ -1,2 +0,0 @@
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
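With these package-level re-exports gone (and the bundled modeling code dropped in the "yeet models" commit), callers import from the submodules and from transformers directly. A short sketch of the import paths this PR itself uses in from_pretrained.py:

```python
# Import paths taken from src/petals/bloom/from_pretrained.py in this PR;
# the old `from petals.bloom import BloomBlock, BloomConfig` no longer works.
from transformers.models.bloom.configuration_bloom import BloomConfig
from petals.bloom.block import WrappedBloomBlock
```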
276 changes: 40 additions & 236 deletions src/petals/bloom/block.py
@@ -3,253 +3,57 @@
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import math
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
import transformers
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor

from petals.bloom.ops import (
BloomGelu,
BloomScaledSoftmax,
attention_mask_func,
build_alibi_tensor,
dropout_add,
pre_process_alibi_for_pad,
split_tensor_along_last_dim,
)
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1"


class BloomAttention(nn.Module):
def __init__(self, config, layer_number=None):
super().__init__()

self.hidden_size = config.hidden_size
self.num_heads = config.n_head
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
self.masked_softmax_fusion = config.masked_softmax_fusion
self.hidden_dropout = config.hidden_dropout

if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)

# Layer-wise attention scaling
self.layer_number = max(1, layer_number)
self.norm_factor = math.sqrt(self.head_dim) * self.layer_number

# Scaled Softmax
self.scale_mask_softmax = BloomScaledSoftmax(
self.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
self.layer_number,
)

self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)

self.attention_dropout = nn.Dropout(config.attention_dropout)

class WrappedBloomBlock(BloomBlock):
def forward(
self,
hidden_states,
residual,
layer_past=None,
attention_mask=None,
alibi=None,
head_mask=None,
use_cache=False,
output_attentions=False,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs
):
assert attention_mask is None
batch_size, seq_length = hidden_states.shape[:2]
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None:
current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
alibi = build_alibi_tensor(
current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
)

# hidden_states: [batch_size, seq_length, hidden_size]
# apply preprocessing if the input is padded
if attention_mask is not None:
alibi = pre_process_alibi_for_pad(alibi, attention_mask)
# otherwise repeat alibi tensor with the batch size
else:
alibi = alibi.repeat(hidden_states.shape[0], 1, 1)

mixed_x_layer = self.query_key_value(hidden_states)

# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)

if use_cache is True:
present = (key_layer, value_layer)
else:
present = None

# [batch_size, head_dim, q_length, k_length]
output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))

# [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)

# [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)

# Raw attention scores. [batch_size * num_heads, q_length, k_length]
beta = 1.0 / self.layer_number

matmul_result = torch.baddbmm(
alibi,
query_layer.transpose(1, 0),
key_layer.transpose(1, 0).transpose(1, 2),
beta=beta,
alpha=(1.0 / self.norm_factor),
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
return super().forward(
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
)

# change view to [batch_size, num_heads, q_length, k_length]
attention_scores = matmul_result.view(*output_size)

# attention scores and attention mask [b, np, sq, sk]
max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
attention_probs = self.attention_dropout(attention_probs)

if head_mask is not None:
attention_probs = attention_probs * head_mask

# context layer shape: [batch_size, num_heads, q_length, head_dim]
output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))

# change view [k_length, batch_size x num_heads, head_dim]
value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)

# change view [batch_size x num_heads, q_length, k_length]
attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))

# change view [batch_size, num_heads, q_length, head_dim]
context_layer = context_layer.view(*output_size)

# [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

# [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)

context_layer = context_layer.view(*new_context_layer_shape)

# Output. [q_length, batch_size, hidden_size]

# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
output_tensor = self.dense(context_layer)
output = output_tensor.transpose(1, 0)

output = dropout_add(output, residual, self.hidden_dropout, self.training)

outputs = (output, present)
if output_attentions:
outputs += (attention_probs,)

return outputs


class BloomMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
self.hidden_dropout = config.hidden_dropout
self.gelu_impl = BloomGelu()

def forward(self, hidden_states, residual):
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
intermediate_output = self.dense_4h_to_h(hidden_states)
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
return output


class BloomBlock(nn.Module):
def __init__(self, config, layer_number=None):
super().__init__()
self.hidden_size = config.hidden_size

self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
self.n_head = config.n_head
self.self_attention = BloomAttention(config, layer_number=layer_number)
self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)

self.mlp = BloomMLP(config)

self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
self.hidden_dropout = config.hidden_dropout

def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=False,
output_attentions=False,
alibi=None,
):
# hidden_states: [batch_size, seq_length, hidden_size]

# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)

# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
def _prepare_attn_mask(
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape

if src_length > 1:
combined_attention_mask = _make_causal_mask(
torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
)

# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)

attention_output = attn_outputs[0]

outputs = attn_outputs[1:]

layernorm_output = self.post_attention_layernorm(attention_output)

# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output

# MLP.
output = self.mlp(layernorm_output, residual)

if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]

return outputs # hidden_states, present, attentions
return combined_attention_mask
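In short, WrappedBloomBlock accepts bare hidden states, rebuilds the full attention mask and ALiBi bias itself (via transformers' _make_causal_mask, _expand_mask and build_alibi_tensor), and then defers to the upstream BloomBlock. A minimal sketch of calling it directly, assuming a toy BloomConfig and transformers 4.25 rather than a real BLOOM checkpoint:

```python
# Minimal sketch (not from this PR): one forward pass through WrappedBloomBlock.
# The config values are toy numbers for illustration; real weights come from load_pretrained_block.
import torch
from transformers.models.bloom.configuration_bloom import BloomConfig

from petals.bloom.block import WrappedBloomBlock

config = BloomConfig(hidden_size=64, n_head=8, n_layer=1)
block = WrappedBloomBlock(config)

hidden_states = torch.randn(1, 5, config.hidden_size)  # [batch_size, seq_length, hidden_size]
# attention_mask must stay None: the wrapper builds an all-ones mask,
# the boolean causal mask and the ALiBi tensor internally.
outputs = block(hidden_states, use_cache=True)
hidden_out, present = outputs[0], outputs[1]
print(hidden_out.shape)                    # torch.Size([1, 5, 64])
print(present[0].shape, present[1].shape)  # cached keys and values for incremental decoding
```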
33 changes: 12 additions & 21 deletions src/petals/bloom/from_pretrained.py
@@ -13,20 +13,17 @@
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.utils.hub import cached_path, hf_bucket_url
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.utils import get_file_from_repo

from petals.bloom import BloomBlock, BloomConfig
from petals.bloom.block import WrappedBloomBlock
from petals.utils.disk_cache import DEFAULT_CACHE_DIR

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"
USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
FORCE_DOWNLOAD = False
RESUME_DOWNLOAD = False
LOCAL_FILES_ONLY = False


def load_pretrained_block(
@@ -36,15 +33,15 @@ def load_pretrained_block(
torch_dtype: Union[torch.dtype, str] = "auto",
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
) -> BloomBlock:
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
) -> WrappedBloomBlock:
"""Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""

if config is None:
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR

block = BloomBlock(config, layer_number=block_index)
block = WrappedBloomBlock(config)
state_dict = _load_state_dict(
converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
)
@@ -70,20 +67,14 @@ def _load_state_dict(
cache_dir: Optional[str] = None,
) -> OrderedDict[str, torch.Tensor]:
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)

# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=FORCE_DOWNLOAD,
proxies=None,
resume_download=RESUME_DOWNLOAD,
local_files_only=LOCAL_FILES_ONLY,
archive_file = get_file_from_repo(
pretrained_model_name_or_path,
filename=WEIGHTS_NAME,
revision=revision,
use_auth_token=use_auth_token,
user_agent=USER_AGENT,
cache_dir=cache_dir,
)
state_dict = torch.load(resolved_archive_file, map_location="cpu")
state_dict = torch.load(archive_file, map_location="cpu")
return state_dict


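For context, load_pretrained_block now resolves weights with transformers' get_file_from_repo instead of the removed cached_path/hf_bucket_url pair, fetching WEIGHTS_NAME from a per-block branch (block_<index>, per BLOCK_BRANCH_PREFIX). A hedged usage sketch; the leading positional parameters (model repo and block index) are not visible in this hunk and the repo name below is a placeholder:

```python
# Hedged usage sketch (not from this PR): load one converted BLOOM block.
# "your-org/converted-bloom" is a placeholder; a real repo must have one branch per block
# (block_0, block_1, ...) holding that block's pytorch_model.bin.
from petals.bloom.from_pretrained import load_pretrained_block

block = load_pretrained_block("your-org/converted-bloom", 3)  # fetches revision "block_3"
print(type(block).__name__)  # WrappedBloomBlock
```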