from typing import List
from os.path import exists as path_exists
from functools import partial

import torch

from transformers import TextStreamer

from Cluster.InfernBatchedWorker import InfernBatchedWorker
from Cluster.InfernTTSWorker import get_torch_hw
from Cluster.LLMSession import LLMResult, LLMInferRequest

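# ResultsStreamer fans a single batched generate() call out to the per-request
# callbacks: it accumulates the sampled token ids per batch element, re-decodes
# them incrementally and pushes a chunk to the request's textout_cb whenever one
# of the sync_on sentence boundaries shows up in the newly decoded text.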
class ResultsStreamer(TextStreamer):
    debug = False
    sync_on = ('. ', '? ', '! ', '\n')
    decode_batch_size = 8
    def __init__(self, wis:List[LLMInferRequest], upper:'InfernLLMWorker'):
        super().__init__(tokenizer=upper.llm_tokenizer)
        self.wi_cbs = tuple(wi.textout_cb for wi in wis)
        self.newLLMResult = tuple(partial(LLMResult, req_id=wi.req.id) for wi in wis)
        batch_size = len(wis)
        self.oposs = [0 for _ in range(batch_size)]
        self.current_tokens = None
        self.batch_decode = partial(upper.llm_tokenizer.batch_decode, skip_special_tokens=True)

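    # Called by generate() at every decoding step with the new token ids for the
    # whole batch; the very first call carries the prompt tokens and is only used
    # to size the accumulation buffer.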
    def put(self, token_ids):
        if self.current_tokens is None:
            self.current_tokens = torch.zeros((token_ids.shape[0], 0), dtype=torch.long)
            return
        if token_ids.dim() == 1:  # shape [batch_size] -> [batch_size, 1]
            token_ids = token_ids.unsqueeze(1)
        self.current_tokens = torch.cat([self.current_tokens, token_ids], dim=1)
        # Only re-decode every decode_batch_size tokens to keep the overhead down.
        if self.current_tokens.shape[1] % self.decode_batch_size != 0:
            return
        results = self.batch_decode(self.current_tokens)
        for (ir, r), op, cb, newLR in zip(enumerate(results), self.oposs, self.wi_cbs, self.newLLMResult):
            new_content = r[op:]
            if len(new_content) == 0: continue
            # Look for a sentence boundary in the not-yet-emitted tail.
            sp = (op + pos + len(c) for c in self.sync_on if (pos:=new_content.rfind(c)) >= 0)
            try:
                spos = next(sp)
            except StopIteration:
                continue
            r = r[op:spos-1]
            if len(r) < 10: continue
            cb(result=newLR(r))
            self.oposs[ir] = spos
        if self.debug:
            print(f'{self.oposs=} {self.current_tokens.shape=}')

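    # Called once generation has finished: flush whatever is left past the last
    # emitted sentence boundary for each request.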
    def end(self):
        if self.debug:
            print(f'finished: {self.current_tokens.shape=}')
        results = self.batch_decode(self.current_tokens)
        for r, op, cb, newLR in zip(results, self.oposs, self.wi_cbs, self.newLLMResult):
            if len(r) == op: continue
            cb(result=newLR(r[op:]))
        del self.current_tokens
        del self.wi_cbs

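# InfernLLMWorker serves batched LLMInferRequest's: it loads Qwen2.5-14B-Instruct
# through ipex_llm in 4-bit and streams partial results back via ResultsStreamer.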
class InfernLLMWorker(InfernBatchedWorker):
    model_name = "Qwen/Qwen2.5-14B-Instruct"
    model_cache_dir = f"/tmp/saved_model.{model_name}"
    max_batch_size: int = 8
    debug = True
    llm_model: object
    llm_tokenizer: object
    output_sr: int

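    # Loading strategy: reuse a previously saved low-bit checkpoint from
    # model_cache_dir when one exists; otherwise fetch model_name, quantize it to
    # 4 bit and cache the converted copy for subsequent start-ups.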
    def __init__(self, device=None):
        from warnings import filterwarnings
        filterwarnings("ignore", category=FutureWarning)
        filterwarnings("ignore", category=UserWarning)
        from transformers import AutoTokenizer
        from ipex_llm.transformers import AutoModelForCausalLM
        super().__init__()
        if device is None:
            device = get_torch_hw()
        def load_model(mn):
            m = AutoModelForCausalLM.from_pretrained(mn, torch_dtype="auto",
                                                     device_map="auto",
                                                     optimize_model=True,
                                                     trust_remote_code=True,
                                                     load_in_4bit=True,
                                                     use_cache=True)
            if mn != self.model_cache_dir:
                m.save_low_bit(self.model_cache_dir)
            return m.to(device)
        if path_exists(self.model_cache_dir):
            try:
                model = AutoModelForCausalLM.load_low_bit(self.model_cache_dir,
                                                          trust_remote_code=True)
            except Exception:
                model = load_model(self.model_name)
        else:
            model = load_model(self.model_name)
        self.llm_model = model.to(device)
        self.llm_tokenizer = AutoTokenizer.from_pretrained(self.model_name)

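    # Render each request's chat context into a prompt with the tokenizer's chat
    # template, tokenize the whole batch with padding, and stream the generation
    # through ResultsStreamer so callers receive text as it is produced.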
    def process_batch(self, wis:List[LLMInferRequest]):
        if self.debug:
            print(f'InfernLLMWorker.process_batch: got {len(wis)=}')
        streamer = ResultsStreamer(wis, self)
        with torch.no_grad():
            messages = [self.llm_tokenizer.apply_chat_template(list(r.context), tokenize=False,
                                                               add_generation_prompt=True)
                        for r in wis]
            model_inputs = self.llm_tokenizer(messages, return_tensors="pt", padding=True).to(self.llm_model.device)
            self.llm_model.generate(
                **model_inputs,
                max_new_tokens=16 * 1024,
                output_scores=True,
                return_dict_in_generate=True,
                streamer=streamer,
            )
            torch.xpu.synchronize()