4 changes: 2 additions & 2 deletions fastdeploy/demo/offline_demo.py
@@ -17,11 +17,11 @@
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

model_name_or_path = "./models/llama-7b"
model_name_or_path = "/home/aistudio/config_folder"

# Hyperparameter settings
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
-llm = LLM(model=model_name_or_path, tensor_parallel_size=1)
+llm = LLM(model=model_name_or_path, tensor_parallel_size=4, load_choices="default_v1")
output = llm.generate(prompts="who are you?", use_tqdm=True, sampling_params=sampling_params)

print(output)
26 changes: 26 additions & 0 deletions fastdeploy/engine/engine.py
@@ -120,6 +120,32 @@ def start(self, api_server_pid=None):
        self._init_worker_signals()

        self.data_processor = self.input_processor.create_processor()

        # [CORE FIX] Check and set pad_token_id if it's missing.
        # Some tokenizers do not have a pad_token_id, which causes issues with padding.
        # We use the eos_token_id as a robust fallback in such cases.
        if self.data_processor.pad_token_id is None:
            eos_token_id = self.data_processor.tokenizer.eos_token_id
            if eos_token_id is not None:
                console_logger.warning(
                    f"Tokenizer's pad_token_id is None. Setting it to the eos_token_id ({eos_token_id}) for padding."
                )
                # 1. Update the tokenizer instance directly. This is crucial for padding operations.
                self.data_processor.tokenizer.pad_token_id = eos_token_id

                # 2. Update the data_processor's attribute to ensure workers are initialized with the correct value.
                self.data_processor.pad_token_id = eos_token_id

                # 3. Update the main model configuration to maintain a single source of truth.
                if self.cfg.model_config.pad_token_id is None:
                    self.cfg.model_config.pad_token_id = eos_token_id
            else:
                # This is a critical failure case. A model must have a token for padding.
                raise ValueError(
                    "Tokenizer has neither a pad_token_id nor an eos_token_id. "
                    "Cannot proceed without a token for padding."
                )

        self.engine.data_processor = self.data_processor
        # Launch components: scheduler, cache_manager, expert_service et al.
        self.launch_components()
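
The hunk above falls back to eos_token_id whenever a tokenizer ships without a pad token. A minimal standalone sketch of the same pattern, using a Hugging Face tokenizer purely as an illustration (transformers and the "gpt2" checkpoint are assumptions for the example, not part of this PR):

from transformers import AutoTokenizer

# GPT-2 is a well-known example of a tokenizer with an eos_token_id but no pad_token_id.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        # Reuse the EOS token so batch padding has a valid id, mirroring the fix above.
        tokenizer.pad_token = tokenizer.eos_token
    else:
        raise ValueError("Tokenizer has neither a pad_token_id nor an eos_token_id.")

print(tokenizer.pad_token_id)  # 50256, i.e. the eos_token_id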
5 changes: 5 additions & 0 deletions fastdeploy/model_executor/forward_meta.py
@@ -142,6 +142,11 @@ class ForwardMeta:
    block_tables: Optional[paddle.Tensor] = None
    # KV caches
    caches: Optional[list[paddle.Tensor]] = None

    # Linear attention caches
    linear_attn_caches: Optional[paddle.Tensor] = None
    # Slot mapping
    slot_mapping: Optional[paddle.Tensor] = None

    def clear_caches(self):
        """Safely clean up the caches"""
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/layers/rotary_embedding.py
@@ -276,7 +276,7 @@ def get_rope_impl(
if model_config is None or architecture.startswith("Qwen"):
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
elif architecture.startswith("Glm"):
elif "MiniMaxM1" in architecture or architecture.startswith("Glm"):
rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
else:
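
For context, both QwenRotaryEmbedding and GlmRotaryEmbedding precompute position-dependent cos/sin tables; the change above simply routes MiniMaxM1 architectures to the GLM variant. A generic sketch of what such a table computation looks like under a partial_rotary_factor (an illustration of rotary embeddings in general, not FastDeploy's implementation; head_dim and the factor value are assumptions):

import numpy as np

def rope_tables(position_ids, head_dim=128, base=10000.0, partial_rotary_factor=0.5):
    # Only a fraction of each head dimension is rotated when the factor is < 1.0.
    rotary_dim = int(head_dim * partial_rotary_factor)
    inv_freq = 1.0 / (base ** (np.arange(0, rotary_dim, 2) / rotary_dim))
    freqs = np.outer(position_ids, inv_freq)  # [seq_len, rotary_dim // 2]
    return np.cos(freqs), np.sin(freqs)       # later applied pairwise to q and k

cos, sin = rope_tables(np.arange(8))
print(cos.shape)  # (8, 32) with head_dim=128 and partial_rotary_factor=0.5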