本项目提供了一个使用 Hugging Face transformers 库和 PyTorch 实现的 BERT 模型微调(Fine-tuning)流程,用于中文文本的情感分类任务。
项目包含了完整的训练、测试和预测脚本,用户可以轻松地使用自己的数据来训练一个二分类(积极/消极)的情感分析模型。
- 基于 BERT: 使用强大的
bert-base-chinese预训练模型,只需少量数据即可达到良好的微调效果。 - 代码简洁: 代码结构清晰,分为
train.py,test.py,predict.py,职责分明。 - 易于使用: 提供了详细的步骤说明,从环境配置到模型预测,一步到位。
- 可扩展性: 用户可以方便地替换成其他预训练模型(如 RoBERTa, ERNIE 等)或修改模型参数。
bert-sentiment-classifier/
├── data/ # 存放数据文件
│ ├── train.csv # 训练数据
│ └── test.csv # 测试数据
├── saved_model/ # 存放训练好的模型文件
├── train.py # 训练模型的脚本
├── test.py # 评估模型性能的脚本
├── predict.py # 使用训练好的模型进行实时预测的脚本
├── requirements.txt # 项目依赖库
└── README.md # 项目说明文件
为了避免依赖冲突,强烈建议为本项目创建一个独立的 Conda 虚拟环境。
# 创建一个名为 bert_env 且使用 Python 3.9 的新环境
conda create --name bert_env python=3.9
# 激活新创建的环境
conda activate bert_env在已激活的虚拟环境中,我们使用一条命令来安装所有必需的库,并锁定 NumPy 的版本以避免已知的兼容性问题。
pip install torch transformers pandas scikit-learn "numpy<2.0"您也可以创建一个 requirements.txt 文件,然后通过 pip install -r requirements.txt 安装。
requirements.txt 文件内容 (修正后):
torch
transformers
pandas
scikit-learn
numpy<2.0
- 在项目根目录下创建一个名为
data的文件夹。 - 准备您的训练数据和测试数据,并将其保存为 CSV 格式。
- CSV 文件必须包含至少两列:
text(文本内容) 和label(情感标签)。- 积极情感:
label = 1 - 消极情感:
label = 0
- 积极情感:
data/train.csv 示例:
text,label
"这款手机的拍照效果真的惊艳到我了,电池续航也很给力。",1
"太失望了,产品描述和实物严重不符,申请退货了。",0
"客服非常有耐心,一步一步教我怎么安装,服务满分!",1
data/test.csv 示例:
text,label
"键盘手感很好,敲击声音清脆,非常适合游戏玩家。",1
"用了不到一个月,充电口就接触不良了,做工有问题。",0
重要提示: 在执行以下任何 python 命令之前,请确保您已经激活了 Conda 环境,并且已经使用 cd 命令进入了 bert-sentiment-classifier 项目的根目录。
# 示例:
# 1. 激活环境
conda activate bert_env
# 2. 进入项目目录
cd path/to/your/bert-sentiment-classifier在项目根目录下,运行 train.py 脚本来开始训练。
python train.py训练完成后,微调好的模型和分词器文件将自动保存在 saved_model 文件夹中。
训练结束后,在同一目录下运行 test.py 脚本来评估模型性能。
python test.py脚本会加载 saved_model 中的模型,并输出准确率 (Accuracy) 和一个详细的分类报告 (Classification Report)。
运行 predict.py 脚本来启动一个交互式预测程序。
python predict.py交互示例:
加载完成!
请输入一句中文进行情感分析 (输入 '退出' 或 'quit' 来结束): 这家餐厅的味道真是太棒了,下次一定还来!
- 文本: '这家餐厅的味道真是太棒了,下次一定还来!'
- 预测情感: 积极
- 置信度: 0.9987
请输入一句中文进行情感分析 (输入 '退出' 或 'quit' 来结束): 退出
您可以在 train.py 和 test.py 脚本的开头部分,根据自己的需求和硬件条件调整以下超参数:
PRE_TRAINED_MODEL_NAME: 更换其他预训练模型,如'hfl/chinese-roberta-wwm-ext'。MAX_LEN: 句子的最大处理长度。BATCH_SIZE: 每个批次的数据量,受 GPU 显存大小影响。EPOCHS: 训练轮数。LEARNING_RATE: 学习率。