Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
6c9799a
Move
WoosukKwon May 6, 2023
bacf49c
http_frontend -> frontend
WoosukKwon May 6, 2023
a16c1f3
Move
WoosukKwon May 6, 2023
9b868d1
Move controller
WoosukKwon May 6, 2023
b2ff569
Minor
WoosukKwon May 7, 2023
38b946b
Fix import errors
WoosukKwon May 7, 2023
62c7175
Move controller back to worker
WoosukKwon May 7, 2023
26aafdc
Rename
WoosukKwon May 7, 2023
8419df6
mv
WoosukKwon May 7, 2023
d9520b4
Add __init__.py
WoosukKwon May 7, 2023
61bf05f
Minor
WoosukKwon May 7, 2023
799ce53
Move set_random_seeds
WoosukKwon May 7, 2023
b6f6d4c
Fix imports
WoosukKwon May 7, 2023
a25b37d
Extract out initialize_dummy_weights
WoosukKwon May 7, 2023
e2ea5cc
Minor
WoosukKwon May 7, 2023
de47b95
Minor
WoosukKwon May 7, 2023
724dc90
Fix import errors on parallel utils
WoosukKwon May 7, 2023
fd2647f
Add __init__.py
WoosukKwon May 7, 2023
7755a7a
Fix parallel_utils
WoosukKwon May 7, 2023
a95bd42
Minor
WoosukKwon May 7, 2023
4dc8e9e
Fix weight loading
WoosukKwon May 7, 2023
e6ffa80
Annotate types
WoosukKwon May 7, 2023
da591aa
Fix type
WoosukKwon May 7, 2023
0ae70da
sample -> sampler
WoosukKwon May 7, 2023
f1d2700
Minor
WoosukKwon May 7, 2023
338b2f4
Merge branch 'main' into refactor-arch
WoosukKwon May 7, 2023
af4c9c3
Enhance model loader
WoosukKwon May 8, 2023
f503a33
Merge branch 'main' into model-mapper
WoosukKwon May 9, 2023
ac481fc
Merge branch 'main' into model-mapper
WoosukKwon May 9, 2023
1dd7bfa
Minor
WoosukKwon May 9, 2023
8990887
Minor
WoosukKwon May 9, 2023
2d71ad5
Minor
WoosukKwon May 9, 2023
58ad6d3
single quote -> double quote
WoosukKwon May 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cacheflow/core/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from cacheflow.frontend.simple_frontend import SimpleFrontend
from cacheflow.logger import init_logger
from cacheflow.model_executor import get_memory_analyzer
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import get_gpu_memory, get_cpu_memory
from cacheflow.worker.controller import Controller, DeviceID

Expand Down
96 changes: 55 additions & 41 deletions cacheflow/model_executor/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,51 @@
from typing import Type

from cacheflow.model_executor.weight_utils import initialize_dummy_weights


_MODELS = {
'gpt2': GPT2LMHeadModel,
'llama': LlamaForCausalLM,
'opt': OPTForCausalLM,
'stablelm': GPTNeoXForCausalLM,
'pythia': GPTNeoXForCausalLM,
'dolly-v2': GPTNeoXForCausalLM,
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
"GPT2LMHeadModel": GPT2LMHeadModel,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
}

_MEMORY_ANALYZERS = {
'gpt2': GPT2MemoryAnalyzer,
'llama': LlamaMemoryAnalyzer,
'opt': OPTMemoryAnalyzer,
'stablelm': GPTNeoXMemoryAnalyzer,
'pythia': GPTNeoXMemoryAnalyzer,
'dolly-v2': GPTNeoXMemoryAnalyzer,
"GPT2LMHeadModel": GPT2MemoryAnalyzer,
"GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
"LlamaForCausalLM": LlamaMemoryAnalyzer,
"OPTForCausalLM": OPTMemoryAnalyzer,
}


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Return the model *class* registered for one of the config's architectures.

    Args:
        config: A HuggingFace model config. Its ``architectures`` attribute
            (a list of class-name strings) is matched against the registry.

    Returns:
        The model class (not an instance) — callers instantiate it themselves.

    Raises:
        ValueError: If none of the listed architectures is supported.
    """
    # NOTE: config.architectures can be None (not just missing) on some
    # HF configs, so fall back to an empty list in both cases.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
    )


def _get_memory_analyzer(
    config: PretrainedConfig) -> Type[CacheFlowMemoryAnalyzer]:
    """Return the memory-analyzer *class* for one of the config's architectures.

    Args:
        config: A HuggingFace model config. Its ``architectures`` attribute
            (a list of class-name strings) is matched against the registry.

    Returns:
        The memory-analyzer class (not an instance); the caller instantiates
        it with model- and hardware-specific arguments.

    Raises:
        ValueError: If none of the listed architectures is supported.
    """
    # NOTE: config.architectures can be None (not just missing) on some
    # HF configs, so fall back to an empty list in both cases.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _MEMORY_ANALYZERS:
            return _MEMORY_ANALYZERS[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
    )


def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
# NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, 'torch_dtype', None)
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
if dtype == 'default':
if dtype == "default":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
Expand All @@ -51,7 +70,7 @@ def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
# TODO(woosuk): Allow using float16 for bfloat16 models and
# vice versa. Print a warning message and continue.
raise ValueError(
f'Cannot use {torch_dtype} for {config_dtype} model.')
f"Cannot use {torch_dtype} for {config_dtype} model.")
return torch_dtype


Expand All @@ -65,24 +84,21 @@ def get_model(
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
torch.set_default_dtype(torch_dtype)
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
if use_dummy_weights:
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(config)
model = model.cuda()
# NOTE(woosuk): For precise performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
else:
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.
model.load_weights(model_name, cache_dir, use_np_cache)
model = model.cuda()
return model.eval(), torch_dtype
raise ValueError(f'Unsupported model name: {model_name}')
model_class = _get_model_architecture(config)

# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(config)
if use_dummy_weights:
model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
model.load_weights(model_name, cache_dir, use_np_cache)
model = model.cuda()
return model.eval(), torch_dtype


def get_memory_analyzer(
Expand All @@ -95,9 +111,7 @@ def get_memory_analyzer(
) -> CacheFlowMemoryAnalyzer:
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
if model_class in model_name:
return memory_analyzer(
model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
tensor_parallel_size)
raise ValueError(f'Unsupported model name: {model_name}')
memory_analyzer = _get_memory_analyzer(config)
return memory_analyzer(
model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
tensor_parallel_size)