Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
GaiZhenbiao committed Dec 25, 2023
2 parents 9c16592 + 31c7630 commit 3abd6e1
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions modules/models/Qwen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from transformers.generation import GenerationConfig
import logging
import colorama
Expand All @@ -9,8 +10,18 @@
class Qwen_Client(BaseLLMModel):
def __init__(self, model_name, user_name="") -> None:
super().__init__(model_name=model_name, user=user_name)
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_METADATA[model_name]["repo_id"], trust_remote_code=True, resume_download=True)
self.model = AutoModelForCausalLM.from_pretrained(MODEL_METADATA[model_name]["repo_id"], device_map="auto", trust_remote_code=True, resume_download=True).eval()
model_source = None
if os.path.exists("models"):
model_dirs = os.listdir("models")
if model_name in model_dirs:
model_source = f"models/{model_name}"
if model_source is None:
try:
model_source = MODEL_METADATA[model_name]["repo_id"]
except KeyError:
model_source = model_name
self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True, resume_download=True)
self.model = AutoModelForCausalLM.from_pretrained(model_source, device_map="auto", trust_remote_code=True, resume_download=True).eval()

def generation_config(self):
return GenerationConfig.from_dict({
Expand Down

0 comments on commit 3abd6e1

Please sign in to comment.