
[BUG] Qwen-7B-Chat 4-bit quantization following the Auto-GPTQ example fails with: ValueError: Pointer argument (at 2) cannot be accessed from Triton (cpu tensor?) #646

Closed
sunyclj opened this issue Nov 17, 2023 · 16 comments

@sunyclj

sunyclj commented Nov 17, 2023

Is there an existing issue / discussion for this?

  • I have searched the existing issues / discussions

Is there an existing answer for this in FAQ?

  • I have searched FAQ

Current Behavior

Execution reaches model.quantize(examples) and then fails with: ValueError: Pointer argument (at 2) cannot be accessed from Triton (cpu tensor?). What could be causing this?

Expected Behavior

No response

Steps To Reproduce

No response

Environment

- OS:
- Python:
- Transformers:
- PyTorch:
- CUDA (`python -c 'import torch; print(torch.version.cuda)'`):

Anything else?

No response

@lonngxiang

Same error here; the message is:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@jklj077
Contributor

jklj077 commented Nov 21, 2023

Your model parameters may have been loaded into CPU memory. Two options to try (see the sketch below):

  1. Force all parameters onto the GPU: don't use device_map='auto'; pass something like 'cuda:0' directly, or leave it unset and move the model to the GPU after loading.
  2. Uninstall flash-attn.
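
A minimal sketch of option 1, assuming the standard auto_gptq API used elsewhere in this thread; the model path is a placeholder:

```python
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)

# Load without device_map='auto' so nothing is sharded or offloaded ...
model = AutoGPTQForCausalLM.from_pretrained(
    "/path/to/Qwen-7B-Chat",  # placeholder path
    quantize_config,
    trust_remote_code=True,
)

# ... then move every parameter onto a single GPU before calling quantize(),
# so no tensor is left behind on the CPU.
if torch.cuda.is_available():
    model.cuda()
```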

@lonngxiang
> Your model parameters may have been loaded into CPU memory. Two options to try:
>
>   1. Force all parameters onto the GPU: don't use device_map='auto'; pass something like 'cuda:0' directly, or leave it unset and move the model to the GPU after loading.
>   2. Uninstall flash-attn.

I tried a single card (cuda:0) but gave up because I don't have enough GPU memory. Can this be run across multiple GPUs?

@lonngxiang
If I quantize on CPU only, I also get an error:

AttributeError: 'QWenLMHeadModel' object has no attribute 'quantize'

@jklj077
Contributor

jklj077 commented Nov 21, 2023

@lonngxiang You still need a GPU; we haven't tried multi-GPU quantization, though. After loading across multiple GPUs (device_map='auto'), if there is enough VRAM there should be no parameters left in CPU memory (if there are, you can print model.hf_device_map to see which modules ended up on the CPU).
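
A quick check along those lines, assuming hf_device_map is exposed on the loaded model as suggested above (it is attached when the model is loaded with device_map='auto'):

```python
# Print where each module ended up; any value of "cpu" or "disk" means that
# module was offloaded and will cause device-mismatch errors during quantization.
print(model.hf_device_map)

offloaded = {name: dev for name, dev in model.hf_device_map.items()
             if dev in ("cpu", "disk")}
if offloaded:
    print("Offloaded modules:", offloaded)
```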

@lonngxiang
> @lonngxiang You still need a GPU; we haven't tried multi-GPU quantization, though. After loading across multiple GPUs (device_map='auto'), if there is enough VRAM there should be no parameters left in CPU memory (if there are, you can print model.hf_device_map to see which modules ended up on the CPU).

Still getting the error:
[screenshot of the error]

Code:

from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import logging

import torch
device=torch.device("cpu")
#device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)

pretrained_model_dir = "/mnt/data/loong/Qwen-7B-Chat"
quantized_model_dir = "/mnt/data/loong/Qwen-7B-Chat-4bit"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True, trust_remote_code=True)
examples = [
    tokenizer(
        "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.", return_tensors="pt").to(device)
]

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # set to False can significantly speed up inference but the perplexity may slightly bad
)

# load un-quantized model, by default, the model will always be loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, trust_remote_code=True, low_cpu_mem_usage=True, device_map='auto')

# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
model.quantize(examples)

# save quantized model
model.save_quantized(quantized_model_dir)

# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)

@sunyclj
Author

sunyclj commented Nov 23, 2023

> quantize_config

Following both of the suggested solutions, I still get the same error: "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"

@sunyclj
Author

sunyclj commented Nov 23, 2023

> Your model parameters may have been loaded into CPU memory. Two options to try:
>
>   1. Force all parameters onto the GPU: don't use device_map='auto'; pass something like 'cuda:0' directly, or leave it unset and move the model to the GPU after loading.
>   2. Uninstall flash-attn.

> model.hf_device_map

When loading the model I set device_map="cuda:1"; why does the error still refer to card 0 and the CPU? "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"

@MiyazonoKaori

MiyazonoKaori commented Nov 26, 2023

Modify modeling_qwen.py so that cos/sin are moved to the same device as t:

```python
def apply_rotary_pos_emb(t, freqs):
    cos, sin = freqs
    cos = cos.to(t.device)
    sin = sin.to(t.device)
    if apply_rotary_emb_func is not None and t.is_cuda:
        t_ = t.float()
        cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
        sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
        return output
    else:
        rot_dim = freqs[0].shape[-1]
        cos, sin = freqs
        cos = cos.to(t.device)
        sin = sin.to(t.device)
        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
        t_ = t_.float()
        t_pass_ = t_pass_.float()
        t_ = (t_ * cos) + (rotate_half(t_) * sin)
        return torch.cat((t_, t_pass_), dim=-1).type_as(t)
```

@WingsLong

Same problem here; quantizing a model I fine-tuned myself fails with:
Traceback (most recent call last):
File "/data/aigc/train_models/autogpt_quantize.py", line 36, in
model.quantize(examples)
File "/data/programs/python310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/data/programs/python310/lib/python3.10/site-packages/auto_gptq/modeling/_base.py", line 359, in quantize
layer(layer_input, **additional_layer_inputs)
File "/data/programs/python310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/programs/python310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/Qwen-1_8B-Chat_novel1207/modeling_qwen.py", line 610, in forward
attn_outputs = self.attn(
File "/data/programs/python310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/programs/python310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/Qwen-1_8B-Chat_novel1207/modeling_qwen.py", line 432, in forward
query = apply_rotary_pos_emb(query, q_pos_emb)
File "/root/.cache/huggingface/modules/transformers_modules/Qwen-1_8B-Chat_novel1207/modeling_qwen.py", line 1345, in apply_rotary_pos_emb
t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@lvheyang

Building on the official AutoGPTQ example:

...
model = AutoGPTQForCausalLM.from_pretrained(...)

# Add the following to force the model onto the GPU
if torch.cuda.is_available():
    model.cuda()

# the rest omitted

I tested this approach and it works for me.

@sunyclj
Author

sunyclj commented Dec 14, 2023

> apply_rotary_pos_emb

With that change quantization works, but loading the quantized weights for inference now fails: FileNotFoundError: Could not find model in ./lora_finetune/qwen_7b_chat_q. It looks like some files are missing.

@sunyclj
Author

sunyclj commented Dec 14, 2023

> Building on the official AutoGPTQ example:
>
> ...
> model = AutoGPTQForCausalLM.from_pretrained(...)
>
> # Add the following to force the model onto the GPU
> if torch.cuda.is_available():
>     model.cuda()
>
> # the rest omitted
>
> I tested this approach and it works for me.

After quantizing, does loading the model for inference work for you? In my case some model files seem to be missing; I get: FileNotFoundError: Could not find model in ./lora_finetune/qwen_7b_chat_q
The directory contents are:
[screenshot of the directory listing]

@sunyclj
Author

sunyclj commented Dec 19, 2023

> Building on the official AutoGPTQ example:
>
> ...
> model = AutoGPTQForCausalLM.from_pretrained(...)
>
> # Add the following to force the model onto the GPU
> if torch.cuda.is_available():
>     model.cuda()
>
> # the rest omitted
>
> I tested this approach and it works for me.

> After quantizing, does loading the model for inference work for you? In my case some model files seem to be missing; I get: FileNotFoundError: Could not find model in ./lora_finetune/qwen_7b_chat_q The directory contents are: [screenshot of the directory listing]

Solved: I can now quantize and run inference, but the output quality is worse than the officially released Int4 quantized weights; I haven't figured out why yet.

@jklj077
Contributor

jklj077 commented Dec 21, 2023

The error in apply_rotary_emb most likely means that AutoGPTQ moves tensors between devices itself, but the coverage is incomplete, so some tensors never get moved. See AutoGPTQ/AutoGPTQ#370 (comment).

> but the output quality is worse than the officially released Int4 quantized weights

Please refer to the following:

The effect of the calibration data cannot be ignored: it needs to follow the same distribution as your application scenario, because GPTQ minimizes the quantization error with respect to the calibration set.
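
For illustration, a minimal sketch of building a calibration set from in-domain text rather than a single generic sentence; the corpus path, field name, and sample count are placeholders, and tokenizer is the one loaded in the code above:

```python
import json

# Collect in-domain texts that match the deployment scenario.
calib_texts = []
with open("/path/to/domain_corpus.jsonl") as f:  # placeholder corpus
    for i, line in enumerate(f):
        if i >= 128:  # a modest number of samples; tune as needed
            break
        calib_texts.append(json.loads(line)["text"])  # placeholder field name

# Tokenize into the list-of-dicts format that model.quantize(examples) expects.
examples = [
    tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
    for text in calib_texts
]

# model.quantize(examples) then minimizes the layer-wise quantization error
# against these in-domain samples.
```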

@jklj077 jklj077 closed this as completed Dec 21, 2023
@qazzombie

> Building on the official AutoGPTQ example:
>
> ...
> model = AutoGPTQForCausalLM.from_pretrained(...)
>
> # Add the following to force the model onto the GPU
> if torch.cuda.is_available():
>     model.cuda()
>
> # the rest omitted
>
> I tested this approach and it works for me.

> After quantizing, does loading the model for inference work for you? In my case some model files seem to be missing; I get: FileNotFoundError: Could not find model in ./lora_finetune/qwen_7b_chat_q The directory contents are: [screenshot of the directory listing]

Hi, how did you solve the missing-file problem? After quantization I also end up without a model file.
