中文文本分类,Bert, RoBerta,ernie,albert,reformer 基于pytorch,开箱即用,支持单标签文本分类和多标签文本分类
目前数据来源于https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch
类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。
数据集划分
数据集 | 数据量 |
---|---|
训练集 | 18万 |
验证集 | 1万 |
测试集 | 1万 |
train.csv示例, 这里class为数字,对应中文标签为data/xinwen/class.txt的行数
text | class |
---|---|
中华女子学院:本科层次仅1专业招男生 | 3 |
两天价网站背后重重迷雾:做个网站究竟要多少钱 | 4 |
东5环海棠公社230-290平2居准现房98折优惠 | 1 |
... | ... |
数据来源:2020语言与智能技术竞赛:事件抽取任务
列表元素对应class.txt中的行数也就是标签个数,每个值,1则为包含该标签,0则不包含该标签
train.csv
label | content | class |
---|---|---|
['人生-结婚'] | 张杰幽默回应说错与谢娜结婚年份:必须上搓衣板 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
['竞赛行为-胜负', '竞赛行为-晋级'] | 拒绝爆冷!王雅繁击败世界第184,韩国赛晋级第二轮 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
['产品行为-发布'] | 华为正式发布鸿蒙OS,有四大技术特性 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
['人生-婚礼'] | 唐艺昕婚礼现场激动落泪,与妈妈告别难掩悲伤,马思纯哭成泪人 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
bert_classification_raw/
├── data # 放数据
│ ├── xinwen #放数据
│ │ ├── {model}_acc.png # 训练时产生, 训练集、验证集正确率曲线
│ │ ├── {model}_loss.png # 训练时产生, loss下降曲线
│ │ ├── test_data_predict.csv # 行test.py得到的对test.csv的预测结果
│ │ ├── class.txt # 行号代表标签类别
│ │ ├── train.csv # 训练数据
│ │ ├── test.csv # 测试数据
│ │ └── dev.csv # 验证数据
│ └─ models # 训练中会产生模型存放在这里
├── config.py # 配置文件
├── api.py # # 通过flask开启http服务, 5000端口 swagger页面, 需使用infer.py
├── dataset.py # pytorch 文本数据整理成bert输入dataloader
├── infer.py # 训练完成后,运行这个可以在终端输入文本在线测试,并且这个文件不依赖其他文件,如果只用在预测,可以拷走这个文件和模型文件在里面配一下模型文件地址即可,目前配置是bert的
├── model.py # 放写模型的代码
├── test.py # 对test.csv文件的整体测试结果,增加一列预测结果,保存到csv文件,文件位置在config.py中配置
├── train.py # 训练
├── run.sh # 训练脚本,循环调用train.py 得到ernie、bert、bert_wwm、roberta、ernie_healthy, reformer等结果
└── utils.py # 小的工具函数
1、在data下建立和xinwen一样的文件夹(包含train.csv,test.csv,dev.csv,class.txt)
2、python train.py --model bert --dir_name xinwen --epochs 20 --batch_size 64
model | acc | loss |
---|---|---|
bert | ||
bert_wwm | ||
roberta | ||
ernie | ||
ernie_healthy | ||
albert | ||
reformer |
model | acc | loss |
---|---|---|
bert | ||
bert_wwm | ||
roberta | ||
ernie | ||
ernie_healthy | ||
albert | ||
reformer |
# 训练
python train.py # 会生成data/xinwen/acc.png,data/xinwen/loss.png,正确率和loss曲线
# 终端直接输入文本测试
python infer.py
# test.csv测试
python test.py # 会生成data/xinwen/test_pred_out
https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch