
Commit 008cf88

hbikki (Harsha Bikki) and Harsha Bikki authored
[Neuron] Adding support for adding/ overriding neuron configuration a… (vllm-project#8062)
Co-authored-by: Harsha Bikki <[email protected]>
1 parent 77d9e51 commit 008cf88

File tree: 8 files changed, +243 -42 lines

@@ -0,0 +1,50 @@
import os

from vllm import LLM, SamplingParams

# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
# Quantizes neuron model weight to int8 ,
# The default config for quantization is int8 dtype.
os.environ['NEURON_QUANT_DTYPE'] = "s8"

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_num_seqs=8,
    # The max_model_len and block_size arguments are required to be same as
    # max sequence length when targeting neuron device.
    # Currently, this is a known limitation in continuous batching support
    # in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    max_model_len=2048,
    block_size=2048,
    # The device can be automatically detected when AWS Neuron SDK is installed.
    # The device argument can be either unspecified for automated detection,
    # or explicitly assigned.
    device="neuron",
    quantization="neuron_quant",
    override_neuron_config={
        "cast_logits_dtype": "bfloat16",
    },
    tensor_parallel_size=2)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

vllm/config.py

+41-28
@@ -1,8 +1,8 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
-                    Type, Union)
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
+                    Optional, Tuple, Type, Union)

 import torch
 from transformers import PretrainedConfig
@@ -115,35 +115,39 @@ class ModelConfig:
             the model name will be the same as `model`.
         limit_mm_per_prompt: Maximum number of data instances per modality
             per prompt. Only applicable for multimodal models.
+        override_neuron_config: Initialize non default neuron config or
+            override default neuron config that are specific to Neuron devices,
+            this argument will be used to configure the neuron config that
+            can not be gathered from the vllm arguments.
     """

     def __init__(
-        self,
-        model: str,
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_context_len_to_capture: Optional[int] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-    ) -> None:
+            self,
+            model: str,
+            tokenizer: str,
+            tokenizer_mode: str,
+            trust_remote_code: bool,
+            dtype: Union[str, torch.dtype],
+            seed: int,
+            revision: Optional[str] = None,
+            code_revision: Optional[str] = None,
+            rope_scaling: Optional[dict] = None,
+            rope_theta: Optional[float] = None,
+            tokenizer_revision: Optional[str] = None,
+            max_model_len: Optional[int] = None,
+            spec_target_max_model_len: Optional[int] = None,
+            quantization: Optional[str] = None,
+            quantization_param_path: Optional[str] = None,
+            enforce_eager: Optional[bool] = None,
+            max_context_len_to_capture: Optional[int] = None,
+            max_seq_len_to_capture: Optional[int] = None,
+            max_logprobs: int = 20,
+            disable_sliding_window: bool = False,
+            skip_tokenizer_init: bool = False,
+            served_model_name: Optional[Union[str, List[str]]] = None,
+            limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+            use_async_output_proc: bool = True,
+            override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -227,6 +231,9 @@ def __init__(
             limit_mm_per_prompt)
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
+
+        self.override_neuron_config = override_neuron_config if is_neuron(
+        ) else None
         self._verify_embedding_mode()
         self._verify_quantization()
         self._verify_cuda_graph()
@@ -275,6 +282,7 @@ def _verify_quantization(self) -> None:
             "experts_int8"
         ]
         tpu_supported_quantization = ["tpu_int8"]
+        neuron_supported_quantization = ["neuron_quant"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

@@ -329,6 +337,11 @@ def _verify_quantization(self) -> None:
                     "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                     " is not set, enabling VLLM_USE_TRITON_AWQ.")
                 envs.VLLM_USE_TRITON_AWQ = True
+            if is_neuron(
+            ) and self.quantization not in neuron_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in Neuron Backend.")

     def _verify_cuda_graph(self) -> None:
         if self.max_seq_len_to_capture is None:
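To make the two Neuron-specific checks above concrete, here is a minimal standalone sketch (a hypothetical helper, not code from this commit) of the behaviour added to ModelConfig: overrides are kept only when running on a Neuron device, and any quantization method other than neuron_quant is rejected there.

from typing import Any, Dict, Optional

NEURON_SUPPORTED_QUANTIZATION = ["neuron_quant"]


def resolve_neuron_config(on_neuron: bool,
                          quantization: Optional[str],
                          override_neuron_config: Optional[Dict[str, Any]]
                          ) -> Optional[Dict[str, Any]]:
    # Mirrors ModelConfig.__init__: overrides are ignored off Neuron.
    if not on_neuron:
        return None
    # Mirrors ModelConfig._verify_quantization for Neuron devices.
    if (quantization is not None
            and quantization not in NEURON_SUPPORTED_QUANTIZATION):
        raise ValueError(f"{quantization} quantization is currently not "
                         "supported in Neuron Backend.")
    return override_neuron_config


# Overrides survive only on Neuron with the neuron_quant method.
assert resolve_neuron_config(True, "neuron_quant",
                             {"cast_logits_dtype": "bfloat16"}
                             ) == {"cast_logits_dtype": "bfloat16"}
assert resolve_neuron_config(False, "awq",
                             {"cast_logits_dtype": "bfloat16"}) is None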

vllm/engine/arg_utils.py

+14-3
@@ -2,8 +2,8 @@
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
-                    Union)
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Type, Union)

 import torch

@@ -149,6 +149,7 @@ class EngineArgs:
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
+    override_neuron_config: Optional[Dict[str, Any]] = None

     def __post_init__(self):
         if self.tokenizer is None:
@@ -742,6 +743,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
+        parser.add_argument(
+            '--override-neuron-config',
+            type=lambda configs: {
+                str(key): value
+                for key, value in
+                (config.split(':') for config in configs.split(','))
+            },
+            default=None,
+            help="override or set neuron device configuration.")
+
        return parser

    @classmethod
@@ -802,7 +813,7 @@ def create_engine_config(self) -> EngineConfig:
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=not self.disable_async_output_proc,
-        )
+            override_neuron_config=self.override_neuron_config)
        cache_config = CacheConfig(
            block_size=self.block_size if self.device != "neuron" else
            self.max_model_len,  # neuron needs block_size = max_model_len
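The new --override-neuron-config flag turns a comma-separated list of key:value pairs into the dict that EngineArgs forwards to ModelConfig. A minimal sketch of that parsing step (a standalone re-statement of the argparse type= lambda above; note that parsed values stay strings):

def parse_override_neuron_config(configs: str) -> dict:
    # Same comma/colon splitting as the `type=` lambda registered above.
    return {
        str(key): value
        for key, value in (config.split(':') for config in configs.split(','))
    }


# e.g. --override-neuron-config "cast_logits_dtype:bfloat16" on the command line
print(parse_override_neuron_config("cast_logits_dtype:bfloat16"))
# -> {'cast_logits_dtype': 'bfloat16'}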

vllm/engine/llm_engine.py

+2
@@ -214,6 +214,7 @@ def __init__(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
+            "override_neuron_config=%s, "
             "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
@@ -232,6 +233,7 @@ def __init__(
             model_config.skip_tokenizer_init,
             model_config.tokenizer_mode,
             model_config.revision,
+            model_config.override_neuron_config,
             model_config.rope_scaling,
             model_config.rope_theta,
             model_config.tokenizer_revision,

vllm/model_executor/layers/quantization/__init__.py

+3
@@ -22,6 +22,8 @@
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQMarlin24Config)
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.layers.quantization.neuron_quant import (
+    NeuronQuantConfig)
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
@@ -46,6 +48,7 @@
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
+    "neuron_quant": NeuronQuantConfig,
 }

@@ -0,0 +1,67 @@
import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional

from torch.nn import Module

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']


class NeuronQuantConfig(QuantizationConfig):
    """Int8 Quantization Config class for Neuron Backend."""

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not valid,"
                f"the quantization datatype should match one of the below types"
                f"{SUPPORTED_QUANT_DTYPE_LIST}")
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        return SUPPORTED_QUANT_DTYPE_LIST

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "This function should not be called with Neuron Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype,
                   quantize_method=quantize_method)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        if find_spec("transformers_neuronx") is not None:
            return self.get_quantization_config()
        else:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_quantization_config(self):
        from transformers_neuronx.config import QuantizationConfig
        return QuantizationConfig(quant_dtype=self.quant_dtype,
                                  dequant_dtype=self.dequant_dtype,
                                  quantize_method=self.quantize_method)
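A brief usage sketch for the new config class (assuming the get_quantization_config helper exported by vllm.model_executor.layers.quantization and an environment where transformers_neuronx is installed):

import os

from vllm.model_executor.layers.quantization import get_quantization_config

# Pick fp8 weight quantization instead of the default "s8" (int8).
os.environ["NEURON_QUANT_DTYPE"] = "f8e4m3fn"

quant_cls = get_quantization_config("neuron_quant")  # -> NeuronQuantConfig
quant_config = quant_cls()  # reads NEURON_QUANT_DTYPE; rejects unsupported dtypes

# Builds the transformers-neuronx QuantizationConfig; get_quant_method() raises
# NotImplementedError instead when transformers_neuronx is not importable.
neuronx_quant_config = quant_config.get_quantization_config()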
