A question about fine-tuning on generated data with punctuation #101

Open
leon1208 opened this issue Dec 18, 2024 · 2 comments

@leon1208

I used the project's aishell.py to generate a dataset with punctuation.
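
To double-check that punctuation actually made it into the generated list, a quick spot-check of a few entries works (a minimal sketch; the `sentence` field name is my assumption about the JSON-lines format aishell.py produces):

```python
# Hypothetical spot-check: print the first few transcripts from the
# generated data list. The "sentence" key is an assumption about the
# format produced by aishell.py.
import json

with open("dataset/train.json", encoding="utf-8") as f:
    for _, line in zip(range(3), f):
        print(json.loads(line).get("sentence"))
```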

I trained with the following command:
python finetune.py --base_model=openai/whisper-medium --output_dir=output/ --per_device_train_batch_size=4 --per_device_eval_batch_size=4

Here are the logs from near the end of training. It looks to me like the model had already overfit, but the run finished normally:

```
{'loss': 0.0402, 'grad_norm': 0.7124148011207581, 'learning_rate': 7.552000635122261e-06, 'epoch': 2.98}
{'loss': 0.0426, 'grad_norm': 1.0721967220306396, 'learning_rate': 6.559622102254684e-06, 'epoch': 2.98}
{'loss': 0.0394, 'grad_norm': 2.0902748107910156, 'learning_rate': 5.567243569387107e-06, 'epoch': 2.98}
{'loss': 0.0453, 'grad_norm': 0.31820055842399597, 'learning_rate': 4.57486503651953e-06, 'epoch': 2.99}
{'loss': 0.0495, 'grad_norm': 1.1558741331100464, 'learning_rate': 3.582486503651953e-06, 'epoch': 2.99}
{'loss': 0.044, 'grad_norm': 0.3208547532558441, 'learning_rate': 2.590107970784376e-06, 'epoch': 2.99}
{'loss': 0.0443, 'grad_norm': 3.8056881427764893, 'learning_rate': 1.597729437916799e-06, 'epoch': 3.0}
{'loss': 0.0379, 'grad_norm': 1.798136830329895, 'learning_rate': 6.053509050492219e-07, 'epoch': 3.0}
100%|██████████| 100818/100818 [35:41:10<00:00, 1.08it/s]
Best checkpoint: output/whisper-medium/checkpoint-100000, eval result: 0.10838370025157928
{'train_runtime': 128473.0379, 'train_samples_per_second': 3.139, 'train_steps_per_second': 0.785, 'train_loss': 0.08579817058079518, 'epoch': 3.0}
```
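
As a back-of-the-envelope sanity check, the step count in the log is internally consistent with the batch size and epoch count:

```python
# Sanity-check arithmetic on the log above: 100818 optimizer steps over
# 3 epochs at a per-device batch size of 4 implies ~134k training samples,
# which also matches train_samples_per_second * train_runtime / 3.
total_steps = 100_818
epochs = 3
batch_size = 4

samples_per_epoch = total_steps * batch_size / epochs
print(samples_per_epoch)  # ~134,424
```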

Then I merged the model, which produced a warning:
```
root@abda881e1afb:/workspace# python merge_lora.py --lora_model=output/whisper-medium/checkpoint-best/ --output_dir=models/
----------- Configuration Arguments -----------
lora_model: output/whisper-medium/checkpoint-best/
output_dir: models/
local_files_only: False

/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:2817: UserWarning: Moving the following attributes in the config to the generation config: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}. You are seeing this warning because you've set generation parameters in the model config, as opposed to in the generation config.
warnings.warn(
Merged model saved to: models/whisper-medium-finetune
```
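
The UserWarning above is only transformers relocating decoding parameters out of the model config; the merge itself succeeds. One way to make the warning go away (a hedged sketch, not necessarily what merge_lora.py does) is to save an explicit generation config next to the merged weights:

```python
# Hedged sketch: persist the decoding parameters (max_length,
# suppress_tokens, ...) in a GenerationConfig so they no longer live
# in the model config. The path mirrors the one used above.
from transformers import WhisperForConditionalGeneration, GenerationConfig

model = WhisperForConditionalGeneration.from_pretrained("models/whisper-medium-finetune")
gen_config = GenerationConfig.from_model_config(model.config)
gen_config.save_pretrained("models/whisper-medium-finetune")
```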

When I evaluate the model, the results are far from what the project reports: the CER comes out close to 0.99.
```
root@abda881e1afb:/workspace# python evaluation.py --model_path=models/whisper-medium-finetune/ --metric=cer
----------- Configuration Arguments -----------
test_data: dataset/test.json
model_path: models/whisper-medium-finetune/
batch_size: 16
num_workers: 8
language: Chinese
remove_pun: True
to_simple: True
timestamps: False
min_audio_len: 0.5
max_audio_len: 30
local_files_only: True
task: transcribe
metric: cer

Reading data list: 100%|██████████| 7176/7176 [00:00<00:00, 123452.45it/s]
Test samples: 7176
0%| | 0/449 [00:00<?, ?it/s]You have passed language=chinese, but also have set forced_decoder_ids to [[1, None], [2, 50359]] which creates a conflict. forced_decoder_ids will be ignored in favor of language=chinese.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
100%|██████████| 449/449 [07:53<00:00, 1.05s/it]
Evaluation result: cer=0.98695
```
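
A CER this close to 1.0 usually means the model produced empty or unrelated text for almost every utterance, which matches the empty infer.py output described below. For reference, here is a minimal sketch of how such a CER is computed (using the Hugging Face `evaluate` library; the exact metric code in evaluation.py may differ):

```python
# Hedged sketch: character error rate with the "evaluate" library.
# An empty prediction against a non-empty reference scores CER = 1.0,
# which would explain a corpus-level CER near 0.99.
import evaluate

cer_metric = evaluate.load("cer")
print(cer_metric.compute(predictions=[""], references=["今天天气很好"]))          # 1.0
print(cer_metric.compute(predictions=["今天天气不错"], references=["今天天气很好"]))  # ~0.333
```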

The most critical problem: when I finally run infer.py, it prints no transcription at all. Even running inference on my own training data produces no output.
```
root@abda881e1afb:/workspace# python infer.py --audio_path=dataset/CL-SH-1202014116_1.wav --model_path=models/whisper-medium-finetune/
----------- Configuration Arguments -----------
audio_path: dataset/CL-SH-1202014116_1.wav
model_path: models/whisper-medium-finetune/
use_gpu: True
language: chinese
num_beams: 1
batch_size: 16
use_compile: False
task: transcribe
assistant_model_path: None
local_files_only: True
use_flash_attention_2: False
use_bettertransformer: False

Device set to use cuda:0
/opt/conda/lib/python3.10/site-packages/transformers/pipelines/automatic_speech_recognition.py:312: FutureWarning: max_new_tokens is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass max_new_tokens as a key inside generate_kwargs instead.
warnings.warn(
/opt/conda/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:512: FutureWarning: The input name inputs is deprecated. Please make sure to use input_features instead.
warnings.warn(
You have passed task=transcribe, but also have set forced_decoder_ids to [[1, None], [2, 50359]] which creates a conflict. forced_decoder_ids will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
```
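
To narrow down whether the pipeline wrapper or the merged model itself is at fault, one can call generate() directly on the same checkpoint (a hedged sketch using standard transformers classes; infer.py itself may wrap things differently):

```python
# Hedged debugging sketch: bypass the ASR pipeline and call generate()
# directly, decoding WITHOUT skipping special tokens so that a model which
# immediately emits end-of-text (an "empty" transcription) becomes visible.
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_dir = "models/whisper-medium-finetune/"
processor = WhisperProcessor.from_pretrained(model_dir)
model = WhisperForConditionalGeneration.from_pretrained(model_dir).to("cuda")

audio, _ = librosa.load("dataset/CL-SH-1202014116_1.wav", sr=16000)
features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to("cuda")

predicted_ids = model.generate(features, language="chinese", task="transcribe", max_new_tokens=255)
print(processor.batch_decode(predicted_ids, skip_special_tokens=False))
```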

Could you help me figure out which step went wrong?

@yeyupiaoling
Owner

@leon1208 Have you modified the token files at all?
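
A quick way to check this (a hedged sketch, not part of the project's scripts) is to compare the fine-tuned tokenizer against the base model's:

```python
# Hedged sketch: confirm the fine-tuned tokenizer still matches the base
# whisper-medium tokenizer (same vocab size and special tokens).
from transformers import WhisperTokenizer

base = WhisperTokenizer.from_pretrained("openai/whisper-medium")
tuned = WhisperTokenizer.from_pretrained("models/whisper-medium-finetune/")

print(len(base) == len(tuned))                              # vocab sizes match?
print(base.all_special_tokens == tuned.all_special_tokens)  # special tokens match?
```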

@leon1208
Author

No, I never touched the token files. It puzzled me, so I tried the whole thing again: I trained for one epoch, tested the CER, and it was close to the documented results, with inference output normal. I then trained a second epoch: also normal. Finally I ran a third epoch, and once again everything was fine. The CER almost matched the project's expected results and inference worked correctly.

I have no idea where I got tripped up the first time.
