2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
@@ -33,7 +33,7 @@
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"YiForCausalLM": ("yi", "YiForCausalLM"),
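Note: each entry in this registry maps a Hugging Face "architectures" string to a (module, class) pair under vllm/model_executor/models/, so the one-line change above points the existing PhiForCausalLM architecture at the renamed phi module. A minimal sketch of how such an entry might be resolved lazily is shown below; resolve_model_cls is a hypothetical helper for illustration, not the exact loader vLLM ships.

```python
import importlib
from typing import Type

# Registry entry as in the diff above: architecture -> (module name, class name).
_MODELS = {
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
}

def resolve_model_cls(architecture: str) -> Type:
    """Hypothetical lazy lookup: import the module only when the
    architecture is actually requested."""
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)
```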
120 changes: 59 additions & 61 deletions vllm/model_executor/models/phi.py
@@ -62,20 +62,6 @@
KVCache = Tuple[torch.Tensor, torch.Tensor]


class PhiEmbedding(nn.Module):

def __init__(self, config: PretrainedConfig):
super().__init__()

self.wte = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)

def forward(self, input_ids: torch.LongTensor):
return self.wte(input_ids)


class PhiAttention(nn.Module):

def __init__(self,
@@ -93,27 +79,22 @@ def __init__(self,
tensor_model_parallel_world_size)

# pylint: disable=C0103
self.Wqkv = QKVParallelLinear(
self.hidden_size,
self.head_size,
self.total_num_heads,
linear_method=linear_method,
)
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.hidden_size,
self.head_size,
self.total_num_heads,
bias=False,
bias=True,
linear_method=linear_method,
)
self.out_proj = RowParallelLinear(
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
linear_method=linear_method,
)

scaling = self.head_size**-0.5
rotary_dim = config.rotary_dim
rotary_dim = int(config.partial_rotary_factor *
(config.hidden_size // config.num_attention_heads))
assert rotary_dim % 2 == 0

# pylint: disable=C0301
@@ -136,12 +117,12 @@ def forward(
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.out_proj(attn_output)
output, _ = self.dense(attn_output)
return output


@@ -166,8 +147,7 @@ def __init__(self,
linear_method=linear_method,
)
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
n_inner)
self.act = get_act_fn(config.hidden_act, quant_config, n_inner)

def forward(self, hidden_states):
hidden_states, _ = self.fc1(hidden_states)
@@ -182,9 +162,9 @@ def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.mixer = PhiAttention(config, linear_method)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, linear_method)
self.mlp = PhiMLP(config, linear_method)

def forward(
@@ -195,8 +175,8 @@ def forward(
input_metadata: InputMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
attn_outputs = self.mixer(
hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
@@ -215,11 +195,14 @@ def __init__(self,
super().__init__()
self.config = config
self.linear_method = linear_method
self.embd = PhiEmbedding(config)
self.h = nn.ModuleList([
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PhiLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

def forward(
self,
@@ -228,27 +211,19 @@ def forward(
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embd(input_ids)
hidden_states = self.embed_tokens(input_ids)
for i in range(self.config.num_hidden_layers):
layer = self.h[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
)
return hidden_states


class PhiCausalLMHead(nn.Module):
hidden_states = self.final_layernorm(hidden_states)

def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
return hidden_states


class PhiForCausalLM(nn.Module):
@@ -260,8 +235,11 @@ def __init__(self,
self.config = config
self.linear_method = linear_method

self.transformer = PhiModel(config, linear_method)
self.lm_head = PhiCausalLMHead(config)
self.model = PhiModel(config, linear_method)

self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size)

def forward(
@@ -271,17 +249,17 @@
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata)
hidden_states = self.lm_head.ln(hidden_states)
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)

return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
head = self.lm_head.linear
head = self.lm_head
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
return next_tokens
@@ -291,17 +269,37 @@ def load_weights(self,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v")
]
params_dict = dict(self.named_parameters())

for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# pylint: disable=E1136
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# pylint: disable=E1136

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
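Note: the reworked load_weights above uses stacked_params_mapping to fold the separate q_proj / k_proj / v_proj tensors found in the Hugging Face checkpoint into the model's single fused qkv_proj parameter. The sketch below illustrates that renaming-and-sharding idea with a simplified loader that just copies each shard into its slice of the fused tensor; naive_qkv_loader, the tensor shapes, and the checkpoint dict are illustrative only, and the real vLLM weight_loader additionally handles tensor parallelism and quantized layouts.

```python
import torch

# (param_name, shard_name, shard_id), as in the diff above.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def naive_qkv_loader(param: torch.Tensor, loaded_weight: torch.Tensor,
                     shard_id: str) -> None:
    """Copy one of q/k/v into its slice of the fused qkv weight
    (no tensor parallelism, equal-sized shards assumed)."""
    shard_size = param.shape[0] // 3
    offset = {"q": 0, "k": 1, "v": 2}[shard_id] * shard_size
    param.data[offset:offset + shard_size].copy_(loaded_weight)

# Toy example: three (8, 16) projections folded into one (24, 16) weight.
qkv_weight = torch.empty(24, 16)
checkpoint = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(8, 16),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(8, 16),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(8, 16),
}
for name, weight in checkpoint.items():
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # Same renaming as load_weights: q_proj/k_proj/v_proj -> qkv_proj.
            fused_name = name.replace(shard_name, param_name)
            assert fused_name.endswith("qkv_proj.weight")
            naive_qkv_loader(qkv_weight, weight, shard_id)
            break
```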