We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
使用 Hugging Face 官方给出的代码修改后
import time

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch_npu
from torch.utils.data import Dataset, DataLoader

from data_prepare import CPMDataset

# Minimal fine-tuning loop for a CPM-Bee checkpoint on an NPU device
# (through torch_npu). Each step prints the loss and the wall-clock time.

torch_npu.npu.set_compile_mode(jit_compile=False)
torch.npu.empty_cache()

# Only the first 100 records are used. The slice goes through
# CPMDataset.__getitem__, so `trainset` is whatever that returns for a slice
# (presumably a plain list of records -- confirm against data_prepare).
trainset = CPMDataset("basic_task_finetune/bee_data/eval.jsonl")
trainset = trainset[:100]
# NOTE(review): no collate_fn is given, so DataLoader applies its default
# collation to each pair of records before the tokenizer sees them. Confirm
# that this is the batch layout prepare_for_finetune expects.
train_loader = DataLoader(trainset, batch_size=2)

model_path = "models/cpm-bee-2b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to('npu')

# NOTE(review): Adam runs with its default hyperparameters (no lr passed).
# If the loss becomes NaN, a smaller learning rate is one thing to try --
# unverified, the actual cause is not visible from this script.
optimizer = torch.optim.Adam(model.parameters())

# Training mode is loop-invariant (nothing below switches to eval), so it is
# set once instead of on every step.
model.train()

# `step` instead of `iter` so the builtin iter() is not shadowed.
for step, data in enumerate(train_loader):
    step_start = time.perf_counter()
    optimizer.zero_grad()

    # Convert the raw batch into model inputs and move them to the model's device.
    input_encoded = tokenizer.prepare_for_finetune(data, max_length=1024).to(model.device)
    outputs = model(**input_encoded)
    loss = outputs.loss

    loss.backward()
    optimizer.step()

    # NOTE(review): the timer stops before loss.item() is called; if device
    # execution is asynchronous this may not cover all device work -- TODO confirm.
    step_time = time.perf_counter() - step_start
    print(f"Step {step}, Loss: {loss.item():.4f}, Time per step: {step_time:.4f} s")
输出的loss为NaN
数据处理为
import json

from torch.utils.data import Dataset


class CPMDataset(Dataset):
    """Dataset backed by a JSON Lines file, one JSON record per line.

    Every non-blank line is parsed with ``json.loads`` and stored unmodified;
    no fields are extracted or reformatted here.
    """

    def __init__(self, jsonl_file) -> None:
        """Load all records from *jsonl_file* into memory.

        Args:
            jsonl_file: Path of a UTF-8 encoded JSON Lines file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.data = []
        with open(jsonl_file, 'r', encoding='utf-8') as file:
            for line in file:
                # Blank lines (e.g. a trailing empty line) are not records;
                # json.loads would raise on them, so skip them.
                if not line.strip():
                    continue
                self.data.append(json.loads(line))

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the record at *idx*.

        *idx* is passed straight to the underlying list, so a slice yields a
        list of records.
        """
        return self.data[idx]


if __name__ == "__main__":
    dataset = CPMDataset('eval.jsonl')
The text was updated successfully, but these errors were encountered:
No branches or pull requests
使用 Hugging Face 官方给出的代码修改后
输出的loss为NaN
数据处理为
The text was updated successfully, but these errors were encountered: