
Commit 0044c8a
add some comment or debug info
charSLee committed Nov 17, 2023
1 parent a5f9a23 commit 0044c8a
Showing 4 changed files with 33 additions and 5 deletions.
1 change: 1 addition & 0 deletions api/core/completion.py
@@ -388,6 +388,7 @@ def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
+ logging.debug(f"prompt: {prompt_messages}\nmodel_limited_tokens: {model_limited_tokens} \t max_tokens: {max_tokens}\tprompt_tokens: {prompt_tokens}\n")
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")

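Note: the hunk above enforces a simple token budget: prompt tokens plus the requested max_tokens must fit inside the model's context limit, and the new debug line records the breakdown when they do not. A minimal standalone sketch of the same arithmetic (the function and variable names here are illustrative, not Dify's actual API):

import logging

logging.basicConfig(level=logging.DEBUG)

def validate_rest_tokens(model_limited_tokens: int, max_tokens: int, prompt_tokens: int) -> None:
    """Raise if the prompt plus the completion budget exceeds the model's context window."""
    rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
    if rest_tokens < 0:
        # Log the budget breakdown before failing, mirroring the debug line added in this commit.
        logging.debug("model_limited_tokens: %d, max_tokens: %d, prompt_tokens: %d",
                      model_limited_tokens, max_tokens, prompt_tokens)
        raise ValueError("Query or prefix prompt is too long: reduce the prefix prompt, "
                         "shrink max_tokens, or switch to a model with a larger context window.")

# Example: a 4096-token model with max_tokens=1024 leaves room for 3072 prompt tokens.
validate_rest_tokens(model_limited_tokens=4096, max_tokens=1024, prompt_tokens=3000)  # passes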
1 change: 1 addition & 0 deletions api/core/helper/moderation.py
@@ -10,6 +10,7 @@


def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
+ # If content moderation is enabled and hosted OpenAI is configured, check whether the input text is compliant
if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and model_provider.provider_name in hosted_config.moderation.providers:
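Note: check_moderation acts as a boolean gate in front of the model call, returning False when hosted moderation flags the text. A hedged sketch of the control flow, with a stub standing in for the real moderation call (all names below are illustrative assumptions, not Dify's implementation):

def is_flagged(text: str) -> bool:
    # Stand-in for the real moderation API call; assume it returns True when content is flagged.
    return "forbidden" in text.lower()

def check_moderation(moderation_enabled: bool, provider_is_hosted_openai: bool, text: str) -> bool:
    """Return True if the text passes moderation, or if moderation does not apply."""
    if moderation_enabled and provider_is_hosted_openai:
        return not is_flagged(text)
    return True

print(check_moderation(True, True, "hello world"))        # True: passes
print(check_moderation(True, True, "a forbidden topic"))  # False: flagged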
5 changes: 4 additions & 1 deletion api/core/model_providers/models/embedding/base.py
@@ -30,6 +30,8 @@ def base_model_name(self) -> str:

@property
def price_config(self) -> dict:
"""返回货币计算的价格配置的属性。
"""
def get_or_default():
default_price_config = {
'completion': decimal.Decimal('0'),
@@ -45,7 +47,8 @@ def get_or_default():
}
return price_config

- self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+ if not hasattr(self, '_price_config'):
+     self._price_config = get_or_default()

logger.debug(f"model: {self.name} price_config: {self._price_config}")
return self._price_config
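Note: this change replaces a conditional expression that reassigned self._price_config on every access with explicit lazy initialization: build the default config on first access, cache it on the instance, and return the cached dict thereafter. A minimal sketch of the pattern (the class is illustrative; the Decimal defaults mirror the diff):

import decimal

class EmbeddingModel:
    @property
    def price_config(self) -> dict:
        # Lazy initialization: build the default config only on first access, then reuse it.
        if not hasattr(self, '_price_config'):
            self._price_config = {
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD',
            }
        return self._price_config

model = EmbeddingModel()
assert model.price_config is model.price_config  # cached: the same dict on every access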
31 changes: 27 additions & 4 deletions api/core/model_providers/models/llm/base.py
@@ -26,6 +26,7 @@ class BaseLLM(BaseProviderModel):
type: ModelType = ModelType.TEXT_GENERATION
deduct_quota: bool = True

+ # Initializes the BaseLLM object: sets the model name, model kwargs, and the streaming flag, and installs the default callback handlers based on whether streaming is enabled
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
@@ -76,6 +77,9 @@ def base_model_name(self) -> str:

@property
def price_config(self) -> dict:
"""
获取模型价格配置,如每个令牌的成本等,并默认为美元计价
"""
def get_or_default():
default_price_config = {
'prompt': decimal.Decimal('0'),
@@ -94,7 +98,9 @@ def get_or_default():
}
return price_config

- self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+ # Initialize lazily if not yet defined
+ if not hasattr(self, '_price_config'):
+     self._price_config = get_or_default()

logger.debug(f"model: {self.name} price_config: {self._price_config}")
return self._price_config
@@ -104,31 +110,38 @@ def run(self, messages: List[PromptMessage],
callbacks: Callbacks = None,
**kwargs) -> LLMRunResult:
"""
+ Run prediction from the given prompt messages and stop words, handling run-time logic such as content moderation, quota checks, and callback handling, and return an LLMRunResult
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
+ # Join all input messages into a single string and run content moderation on it
moderation_result = moderation.check_moderation(
self.model_provider,
"\n".join([message.content for message in messages])
)

+ # If moderation flags the content (moderation_result is False),
+ # set fake_response in kwargs as a preset safe reply
if not moderation_result:
kwargs['fake_response'] = "I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest."

+ # Quota check: make sure usage has not exceeded the quota
if self.deduct_quota:
self.model_provider.check_quota_over_limit()

+ # Set up callbacks
if not callbacks:
callbacks = self.callbacks
else:
callbacks.extend(self.callbacks)

- if 'fake_response' in kwargs and kwargs['fake_response']:
+ # If fake_response is present in kwargs, skip the real prediction call and create a FakeLLM instance to generate a fake result
+ if 'fake_response' in kwargs:
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
response=kwargs['fake_response'],
@@ -139,6 +152,7 @@ def run(self, messages: List[PromptMessage],
result = fake_llm.generate([prompts])
else:
try:
+ # Invoke the model's own run logic, _run, to execute the prediction
result = self._run(
messages=messages,
stop=stop,
@@ -147,7 +161,8 @@ def run(self, messages: List[PromptMessage],
)
except Exception as ex:
raise self.handle_exceptions(ex)


+ # Parse the LLMResult returned by _run to extract the generated content
function_call = None
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
@@ -156,6 +171,7 @@ def run(self, messages: List[PromptMessage],
else:
completion_content = result.generations[0][0].text

+ # If streaming is enabled but the model does not support it, use FakeLLM to simulate streaming.
if self.streaming and not self.support_streaming:
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
@@ -167,21 +183,27 @@ def run(self, messages: List[PromptMessage],
)
fake_llm.generate([prompts])

+ # Compute token usage
+ # If the result includes token usage info, use it to get the prompt and completion token counts.
if result.llm_output and result.llm_output['token_usage']:
prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
completion_tokens = result.llm_output['token_usage']['completion_tokens']
total_tokens = result.llm_output['token_usage']['total_tokens']
else:
+ # Otherwise, fall back to get_num_tokens to count the prompt and completion tokens.
prompt_tokens = self.get_num_tokens(messages)
completion_tokens = self.get_num_tokens(
[PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens

+ # Update the model's last-used timestamp
self.model_provider.update_last_used()


+ # If quota deduction is enabled, deduct the consumed quota
if self.deduct_quota:
self.model_provider.deduct_quota(total_tokens)

+ # Return an LLMRunResult containing the generated content, prompt token count, completion token count, and any function call info
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
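Note: taken together, the comments added to run() describe a pipeline: moderate the joined input, check quota, either substitute a canned safe reply or call _run, count tokens (from llm_output when the provider reports usage, otherwise by re-tokenizing), then update last-used time and deduct quota. A condensed, hedged sketch of that control flow (every name below is an illustrative stand-in, not the actual Dify classes):

from dataclasses import dataclass
from typing import List

@dataclass
class RunResult:
    content: str
    prompt_tokens: int
    completion_tokens: int

SAFE_REPLY = ("I apologize for any confusion, but I'm an AI assistant "
              "to be helpful, harmless, and honest.")

def run(messages: List[str], passes_moderation, generate, count_tokens,
        quota_ok: bool = True) -> RunResult:
    # 1. Content moderation over the joined input.
    if not passes_moderation("\n".join(messages)):
        completion, usage = SAFE_REPLY, None      # canned reply instead of a real model call
    elif not quota_ok:
        raise RuntimeError("quota exceeded")      # 2. Quota check before calling the model.
    else:
        completion, usage = generate(messages)    # 3. The actual prediction call.

    # 4. Token accounting: prefer usage reported by the model, else re-tokenize.
    if usage:
        prompt_tokens = usage['prompt_tokens']
        completion_tokens = usage['completion_tokens']
    else:
        prompt_tokens = count_tokens(messages)
        completion_tokens = count_tokens([completion])

    return RunResult(completion, prompt_tokens, completion_tokens)

# Example with trivial stand-ins: a whitespace tokenizer and a model that reports usage.
result = run(["Hello there"],
             passes_moderation=lambda text: True,
             generate=lambda msgs: ("Hi!", {'prompt_tokens': 2, 'completion_tokens': 1}),
             count_tokens=lambda msgs: sum(len(m.split()) for m in msgs))
print(result)  # RunResult(content='Hi!', prompt_tokens=2, completion_tokens=1)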
@@ -332,6 +354,7 @@ def _get_prompt_from_messages(self, messages: List[PromptMessage],

def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""
+ Convert model kwargs to the provider model's kwargs, taking into account the aliases, defaults, and min/max bounds defined in the rules
convert model kwargs to provider model kwargs.
:param model_rules:
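Note: the _to_model_kwargs_input docstring describes mapping user-facing kwargs onto provider kwargs via per-parameter rules (alias, default, min, max). A hedged sketch of what such a conversion might look like (the rule shape below is an assumption based on the docstring, not the actual ModelKwargsRules schema):

from typing import Any, Dict

def to_provider_kwargs(rules: Dict[str, dict], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Map user kwargs to provider kwargs, applying alias, default, and min/max clamping."""
    provider_kwargs: Dict[str, Any] = {}
    for name, rule in rules.items():
        value = kwargs.get(name, rule.get('default'))
        if value is None:
            continue  # parameter not supplied and no default defined
        if rule.get('min') is not None:
            value = max(value, rule['min'])  # clamp up to the minimum
        if rule.get('max') is not None:
            value = min(value, rule['max'])  # clamp down to the maximum
        provider_kwargs[rule.get('alias', name)] = value
    return provider_kwargs

rules = {'temperature': {'default': 1.0, 'min': 0.0, 'max': 2.0},
         'max_tokens': {'alias': 'max_output_tokens', 'min': 1, 'max': 4096}}
print(to_provider_kwargs(rules, {'temperature': 3.5, 'max_tokens': 256}))
# {'temperature': 2.0, 'max_output_tokens': 256}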

