diff --git a/modules/models/LLaMA.py b/modules/models/LLaMA.py
index eca013ed..e7c9a2b4 100644
--- a/modules/models/LLaMA.py
+++ b/modules/models/LLaMA.py
@@ -88,6 +88,16 @@ def _get_llama_style_input(self):
             else:
                 context.append(conv["content"] + OUTPUT_POSTFIX)
         return "".join(context)
+        # for conv in self.history:
+        #     if conv["role"] == "system":
+        #         context.append(conv["content"])
+        #     elif conv["role"] == "user":
+        #         context.append(
+        #             conv["content"]
+        #         )
+        #     else:
+        #         context.append(conv["content"])
+        # return "\n\n".join(context)+"\n\n"

     def get_answer_at_once(self):
         context = self._get_llama_style_input()
@@ -105,7 +115,7 @@ def get_answer_stream_iter(self):
         iter = self.model(
             context,
             max_tokens=self.max_generation_token,
-            stop=[],
+            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX, OUTPUT_POSTFIX],
             echo=False,
             stream=True,
         )
diff --git a/modules/models/base_model.py b/modules/models/base_model.py
index 155e246b..2fb46bf9 100644
--- a/modules/models/base_model.py
+++ b/modules/models/base_model.py
@@ -176,7 +176,7 @@ def get_type(cls, model_name: str):
         elif "星火大模型" in model_name_lower:
             model_type = ModelType.Spark
         else:
-            model_type = ModelType.Unknown
+            model_type = ModelType.LLaMA
         return model_type
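For context, here is a minimal sketch of how the new `stop` list interacts with llama-cpp-python's streaming call. The marker values below are assumptions based on the standard LLaMA-2 chat template; the actual `SYS_PREFIX`/`INST_PREFIX`/etc. constants are defined elsewhere in `modules/models/LLaMA.py` and the model path is hypothetical.

```python
# Sketch only: assumed marker values following the common LLaMA-2 chat template.
from llama_cpp import Llama

SYS_PREFIX = "<<SYS>>\n"
SYS_POSTFIX = "\n<</SYS>>\n\n"
INST_PREFIX = "<s>[INST] "
INST_POSTFIX = " "
OUTPUT_PREFIX = "[/INST] "
OUTPUT_POSTFIX = "</s>"

model = Llama(model_path="path/to/model.gguf")  # hypothetical path

# A single-turn prompt built the same way _get_llama_style_input assembles history.
prompt = (
    INST_PREFIX + SYS_PREFIX + "You are a helpful assistant." + SYS_POSTFIX
    + "Hello!" + INST_POSTFIX + OUTPUT_PREFIX
)

# Passing the role markers as stop sequences makes generation halt as soon as the
# model starts emitting a new turn, instead of running on past its own answer.
for chunk in model(
    prompt,
    max_tokens=256,
    stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX, OUTPUT_POSTFIX],
    echo=False,
    stream=True,
):
    print(chunk["choices"][0]["text"], end="", flush=True)
```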