3 changes: 2 additions & 1 deletion pyproject.toml
@@ -14,4 +14,5 @@ profile = "black"
line_length = 120
combine_as_imports = true
combine_star = true
known_local_folder = ["tests", "cli"]
known_local_folder = ["tests", "cli"]
known_first_party = ["test_utils"]
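
Note: adding test_utils to known_first_party is what lets isort pin the from test_utils import ... lines to a stable section in the test files below, regardless of whether the helper module happens to be importable on the machine running the linter. A quick way to confirm the classification, as a hedged sketch assuming isort ≥ 5 is installed and the snippet runs from the repository root:

from isort import Config, place_module

# Pick up the [tool.isort] table from pyproject.toml in the current directory.
config = Config(settings_path=".")

print(place_module("test_utils", config=config))  # expected: FIRSTPARTY, per known_first_party
print(place_module("torch", config=config))       # expected: THIRDPARTY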
26 changes: 16 additions & 10 deletions src/petals/bloom/from_pretrained.py
@@ -13,6 +13,8 @@
from typing import Optional, OrderedDict, Union

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.models.bloom.configuration_bloom import BloomConfig
@@ -38,13 +40,16 @@ def load_pretrained_block(
max_disk_space: Optional[int] = None,
) -> WrappedBloomBlock:
"""Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"

if config is None:
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR

block = WrappedBloomBlock(config)
with init_empty_weights():
block = WrappedBloomBlock(config)

state_dict = _load_state_dict(
converted_model_name_or_path,
block_index,
@@ -54,16 +59,17 @@
max_disk_space=max_disk_space,
)

if torch_dtype == "auto":
with torch.no_grad():
for name, param in block.named_parameters():
assert name in state_dict, f"{name} not in state dict"
param.data = param.data.to(state_dict[name].dtype)
else:
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
block = block.to(dtype=torch_dtype)

# dummy load, check that keys match
report = block.load_state_dict(state_dict, strict=True)
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

for param_name, _ in block.named_parameters():
assert param_name in state_dict, f"{param_name} not in state dict"
param = state_dict[param_name]
if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
param = param.to(torch_dtype)
set_module_tensor_to_device(block, param_name, "cpu", value=param)

logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
return block

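The rewritten loader above instantiates the block with meta ("empty") weights and then materializes every tensor straight from the checkpoint, so parameters are never allocated twice. A minimal sketch of the same pattern, assuming a stand-in nn.Linear in place of WrappedBloomBlock and a synthetic state dict in place of _load_state_dict (the real code additionally casts floating-point tensors to the requested torch_dtype before materializing them):

import torch
from torch import nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# 1) Instantiate on the "meta" device: no memory is allocated for the weights.
with init_empty_weights():
    module = nn.Linear(1024, 1024)
assert module.weight.device.type == "meta"

# 2) Stand-in for the state dict returned by _load_state_dict(...).
state_dict = nn.Linear(1024, 1024).state_dict()

# 3) Materialize each parameter on CPU directly from the checkpoint tensor.
for name, _ in module.named_parameters():
    assert name in state_dict, f"{name} not in state dict"
    set_module_tensor_to_device(module, name, "cpu", value=state_dict[name])

assert module.weight.device.type == "cpu"
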
2 changes: 1 addition & 1 deletion src/petals/server/block_utils.py
@@ -30,7 +30,7 @@ def get_block_size(
dtype is not None and load_in_8bit is not None
), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'

with init_empty_weights():
with init_empty_weights(include_buffers=True):
block = WrappedBloomBlock(config)
n_params = sum(param.numel() for param in block.parameters())

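Passing include_buffers=True places the block's buffers on the meta device along with its parameters, so get_block_size can estimate a block's footprint without allocating any real memory. A rough sketch of the idea, assuming a stand-in module rather than the actual WrappedBloomBlock:

import torch
from torch import nn
from accelerate import init_empty_weights

with init_empty_weights(include_buffers=True):
    # BatchNorm1d is used here only because it owns buffers (running_mean/var);
    # with include_buffers=True those land on the meta device as well.
    block = nn.Sequential(nn.Linear(4096, 4096), nn.BatchNorm1d(4096))

n_params = sum(param.numel() for param in block.parameters())
bytes_per_value = torch.finfo(torch.bfloat16).bits // 8  # 2 bytes per bf16 value
print(f"~{n_params * bytes_per_value / 1024 ** 2:.1f} MiB of weights in bfloat16")
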
4 changes: 2 additions & 2 deletions tests/test_aux_functions.py
@@ -1,9 +1,9 @@
import pytest
import torch
from test_utils import MODEL_NAME

from petals.client import DistributedBloomConfig
from petals.server.throughput import measure_compute_rps, measure_network_rps
from petals.server.throughput import measure_compute_rps
from test_utils import MODEL_NAME


@pytest.mark.forked
51 changes: 49 additions & 2 deletions tests/test_block_exact_match.py
@@ -1,15 +1,18 @@
import random
from typing import Union

import hivemind
import pytest
import torch
from test_utils import *
from transformers.models.bloom.configuration_bloom import BloomConfig

from petals.bloom.from_pretrained import load_pretrained_block
from petals.bloom.block import WrappedBloomBlock
from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
from petals.client import DistributedBloomConfig
from petals.client.remote_sequential import RemoteTransformerBlock
from petals.data_structures import UID_DELIMITER
from petals.dht_utils import get_remote_module
from test_utils import *


@pytest.mark.forked
@@ -41,3 +44,47 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):

assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)


def _old_load_pretrained_block(
converted_model_name_or_path: str,
block_index: int,
torch_dtype: Union[torch.dtype, str] = "auto",
) -> WrappedBloomBlock:
"""Load the BLOOM block by directly initializing the weights.
This test is used to check consistency with the previous implementation and can be removed in the future."""
config = BloomConfig.from_pretrained(converted_model_name_or_path)

block = WrappedBloomBlock(config)
state_dict = _load_state_dict(
converted_model_name_or_path,
block_index,
config,
cache_dir=None,
)

if torch_dtype == "auto":
with torch.no_grad():
for name, param in block.named_parameters():
assert name in state_dict, f"{name} not in state dict"
param.data = param.data.to(state_dict[name].dtype)
else:
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
block = block.to(dtype=torch_dtype)

block.load_state_dict(state_dict, strict=True)
return block


@pytest.mark.forked
def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
torch.random.manual_seed(0)
inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)

block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)

outputs = block.forward(inputs)[0]
outputs_ref = ref_block.forward(inputs)[0]
assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)
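
Besides comparing forward outputs, the two loading paths could also be checked parameter-by-parameter; a hypothetical helper along these lines (not part of the PR):

def _assert_same_parameters(block, ref_block, atol=0.0):
    """Hypothetical check that every parameter of `block` exactly matches the reference block."""
    ref_params = dict(ref_block.named_parameters())
    for name, param in block.named_parameters():
        assert name in ref_params, f"{name} is missing from the reference block"
        assert torch.allclose(param, ref_params[name], rtol=0, atol=atol), f"{name} differs"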
2 changes: 1 addition & 1 deletion tests/test_chained_calls.py
@@ -7,12 +7,12 @@
import hivemind
import pytest
import torch
from test_utils import *

from petals.bloom.from_pretrained import load_pretrained_block
from petals.client import DistributedBloomConfig
from petals.client.remote_sequential import RemoteSequential
from petals.dht_utils import get_remote_sequence
from test_utils import *


@pytest.mark.forked
2 changes: 1 addition & 1 deletion tests/test_full_model.py
@@ -2,11 +2,11 @@
import torch
import transformers
from hivemind import get_logger
from test_utils import *
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM

from petals.client.remote_model import DistributedBloomForCausalLM
from test_utils import *

logger = get_logger(__name__)

4 changes: 2 additions & 2 deletions tests/test_remote_sequential.py
@@ -1,14 +1,14 @@
import pytest
import torch
import torch.nn.functional as F
from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
from hivemind import DHT, BatchTensorDescriptor, get_logger
from hivemind.proto import runtime_pb2
from test_utils import *

from petals.bloom.from_pretrained import load_pretrained_block
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from test_utils import *

logger = get_logger(__name__)

2 changes: 1 addition & 1 deletion tests/test_sequence_manager.py
@@ -4,11 +4,11 @@
import pytest
import torch
from hivemind import DHT, get_logger
from test_utils import *

from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from test_utils import *

logger = get_logger(__name__)

2 changes: 1 addition & 1 deletion tests/test_server_stats.py
@@ -3,12 +3,12 @@
import hivemind
import pytest
import torch
from test_utils import *

from petals.client import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from petals.dht_utils import get_remote_sequence
from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *


@pytest.mark.forked
2 changes: 1 addition & 1 deletion tests/test_tensor_parallel.py
@@ -5,9 +5,9 @@
import transformers
from tensor_parallel import TensorParallel
from tensor_parallel.slicing_configs import get_bloom_config
from test_utils import MODEL_NAME

from petals.bloom.from_pretrained import load_pretrained_block
from test_utils import MODEL_NAME


@pytest.mark.forked
Expand Down