Skip to content

Commit

Permalink
Merge pull request #365 from yangheng95/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
yangheng95 authored Nov 29, 2023
2 parents e06a0fd + f5f8d05 commit 07a251f
Show file tree
Hide file tree
Showing 34 changed files with 786 additions and 247 deletions.
100 changes: 14 additions & 86 deletions examples-v2/aspect_polarity_classification/ensemble_inference.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,12 @@
# -*- coding: utf-8 -*-
# file: inference.py
# time: 05/11/2022 19:48
# author: YANG, HENG <[email protected]> (杨恒)
# github: https://github.com/yangheng95
# GScholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# ResearchGate: https://www.researchgate.net/profile/Heng-Yang-17/research
# Copyright (C) 2022. All Rights Reserved.
import random

import findfile
import tqdm
from sklearn import metrics
from sklearn.metrics import classification_report

from pyabsa import AspectPolarityClassification as APC

from pyabsa import (
AspectPolarityClassification as APC,
ModelSaveOption,
DeviceTypeOption,
)
import warnings

from pyabsa.tasks.AspectPolarityClassification import APCDatasetList
from pyabsa.utils import VoteEnsemblePredictor
from pyabsa.utils.pyabsa_utils import fprint, rprint
import warnings

warnings.filterwarnings("ignore")

Expand All @@ -51,53 +33,24 @@ def ensemble_performance(dataset, print_result=False):
result = ensemble_predict(apc_classifiers, text, print_result)
pred.append(result)
gold.append(text.split("$LABEL$")[-1].strip())
fprint(classification_report(gold, pred, digits=4))
print(classification_report(gold, pred, digits=4))


if __name__ == "__main__":
# Training the models before ensemble inference, take Laptop14 as an example

# for dataset in [
# APCDatasetList.Laptop14,
# # APCDatasetList.Restaurant14,
# # APCDatasetList.Restaurant15,
# # APCDatasetList.Restaurant16,
# # APCDatasetList.MAMS
# ]:
# for model in [
# APC.APCModelList.FAST_LSA_T_V2,
# APC.APCModelList.FAST_LSA_S_V2,
# # APC.APCModelList.BERT_SPC_V2 # BERT_SPC_V2 is slow in ensemble inference so we don't use it
# ]:
# config = APC.APCConfigManager.get_apc_config_english()
# config.model = model
# config.pretrained_bert = 'microsoft/deberta-v3-base'
# config.evaluate_begin = 5
# config.max_seq_len = 80
# config.num_epoch = 30
# config.log_step = 10
# config.patience = 10
# config.dropout = 0
# config.cache_dataset = False
# config.l2reg = 1e-8
# config.lsa = True
# config.seed = [random.randint(0, 10000) for _ in range(3)]
#
# APC.APCTrainer(config=config,
# dataset=dataset,
# checkpoint_save_mode=ModelSaveOption.SAVE_MODEL_STATE_DICT,
# auto_device=DeviceTypeOption.AUTO,
# ).destroy()

for dataset in [Laptop14, Restaurant14, Restaurant15, Restaurant16, MAMS]:
# Training
pass
# Ensemble inference
dataset_file_dict = {
# 'laptop14': findfile.find_cwd_files(['laptop14', '.inference'], exclude_key=[]),
"laptop14": "integrated_datasets/apc_datasets/110.SemEval/113.laptop14/Laptops_Test_Gold.xml.seg.inference",
"restaurant14": "integrated_datasets/apc_datasets/110.SemEval/114.restaurant14/Restaurants_Test_Gold.xml.seg.inference",
"restaurant15": "integrated_datasets/apc_datasets/110.SemEval/115.restaurant15/restaurant_test.raw.inference",
"restaurant16": "integrated_datasets/apc_datasets/110.SemEval/116.restaurant16/restaurant_test.raw.inference",
"twitter": "integrated_datasets/apc_datasets/120.Twitter/120.twitter/twitter_test.raw.inference",
"mams": "integrated_datasets/apc_datasets/109.MAMS/test.xml.dat.inference",
"laptop14": "Laptops_Test_Gold.xml.seg.inference",
"restaurant14": "Restaurants_Test_Gold.xml.seg.inference",
"restaurant15": "restaurant_test.raw.inference",
"restaurant16": "restaurant_test.raw.inference",
"twitter": "twitter_test.raw.inference",
"mams": "test.xml.dat.inference",
}

checkpoints = {
Expand All @@ -112,7 +65,7 @@ def ensemble_performance(dataset, print_result=False):
for key, files in dataset_file_dict.items():
text_classifiers = {}

fprint(f"Ensemble inference")
print(f"Ensemble inference")
lines = []
if isinstance(files, str):
files = [files]
Expand Down Expand Up @@ -141,30 +94,5 @@ def ensemble_performance(dataset, print_result=False):
accuracy = count1 / (i + 1)
it.set_description(f"Accuracy: {accuracy:.4f}")

rprint(metrics.classification_report(batch_gold, batch_pred, digits=4))
fprint(f"Final accuracy: {accuracy}")

while True:
text = input("Please input your text sequence: ")
if text == "exit":
break
if text == "":
continue
_, _, true_label = text.partition("$LABEL$")
try:
result = ensemble_predictor.predict(
text, ignore_error=False, print_result=False
)
print(result)
pred_label = result["label"]
confidence = result["confidence"]
fprint(
"Predicted Label:",
pred_label,
"Reference Label: ",
true_label,
"Correct: ",
pred_label == true_label,
)
except Exception as e:
fprint(e)
print(metrics.classification_report(batch_gold, batch_pred, digits=4))
print(f"Final accuracy: {accuracy}")
153 changes: 153 additions & 0 deletions examples-v2/aspect_polarity_classification/lsa_e_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
# file: inference.py
# time: 05/11/2022 19:48
# author: YANG, HENG <[email protected]> (杨恒)
# github: https://github.com/yangheng95
# GScholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# ResearchGate: https://www.researchgate.net/profile/Heng-Yang-17/research
# Copyright (C) 2022. All Rights Reserved.
import random

import findfile
import tqdm
from sklearn import metrics
from sklearn.metrics import classification_report

from pyabsa import AspectPolarityClassification as APC

from pyabsa import (
AspectPolarityClassification as APC,
ModelSaveOption,
DeviceTypeOption,
)
import warnings

from pyabsa.tasks.AspectPolarityClassification import APCDatasetList
from pyabsa.utils import VoteEnsemblePredictor
from pyabsa.utils.pyabsa_utils import fprint, rprint

warnings.filterwarnings("ignore")


if __name__ == "__main__":
# Training the models before ensemble inference, take Laptop14 as an example

# for dataset in [
# APCDatasetList.Laptop14,
# # APCDatasetList.Restaurant14,
# # APCDatasetList.Restaurant15,
# # APCDatasetList.Restaurant16,
# # APCDatasetList.MAMS
# ]:
# for model in [
# APC.APCModelList.FAST_LSA_T_V2,
# APC.APCModelList.FAST_LSA_S_V2,
# # APC.APCModelList.BERT_SPC_V2 # BERT_SPC_V2 is slow in ensemble inference so we don't use it
# ]:
# config = APC.APCConfigManager.get_apc_config_english()
# config.model = model
# config.pretrained_bert = 'microsoft/deberta-v3-base'
# config.evaluate_begin = 5
# config.max_seq_len = 80
# config.num_epoch = 30
# config.log_step = 10
# config.patience = 10
# config.dropout = 0
# config.cache_dataset = False
# config.l2reg = 1e-8
# config.lsa = True
# config.seed = [random.randint(0, 10000) for _ in range(3)]
#
# APC.APCTrainer(config=config,
# dataset=dataset,
# checkpoint_save_mode=ModelSaveOption.SAVE_MODEL_STATE_DICT,
# auto_device=DeviceTypeOption.AUTO,
# ).destroy()
# Ensemble inference
dataset_file_dict = {
# 'laptop14': findfile.find_cwd_files(['laptop14', '.inference'], exclude_key=[]),
"laptop14": "integrated_datasets/apc_datasets/110.SemEval/113.laptop14/Laptops_Test_Gold.xml.seg.inference",
"restaurant14": "integrated_datasets/apc_datasets/110.SemEval/114.restaurant14/Restaurants_Test_Gold.xml.seg.inference",
"restaurant15": "integrated_datasets/apc_datasets/110.SemEval/115.restaurant15/restaurant_test.raw.inference",
"restaurant16": "integrated_datasets/apc_datasets/110.SemEval/116.restaurant16/restaurant_test.raw.inference",
"twitter": "integrated_datasets/apc_datasets/120.Twitter/120.twitter/twitter_test.raw.inference",
"mams": "integrated_datasets/apc_datasets/109.MAMS/test.xml.dat.inference",
}
for model_name in ["bert-base-uncased"]:
for dataset in [
"laptop14",
"restaurant14",
"restaurant15",
"restaurant16",
"mams",
]:
if len(findfile.find_cwd_dirs(key=[f"{dataset}_acc", model_name])) == 0:
rprint(f"No checkpoints found for {dataset} {model_name}")
continue

checkpoints = {
ckpt: APC.SentimentClassifier(checkpoint=ckpt)
for ckpt in findfile.find_cwd_dirs(key=[f"{dataset}_acc", model_name])
}

ensemble_predictor = VoteEnsemblePredictor(
checkpoints, weights=None, numeric_agg="mean", str_agg="max_vote"
)

files = dataset_file_dict[dataset]
text_classifiers = {}

lines = []
if isinstance(files, str):
files = [files]
for file in files:
with open(file, "r") as f:
lines.extend(f.readlines())

# 测试总体准确率 batch predict
# eval acc
count1 = 0
accuracy = 0
batch_pred = []
batch_gold = []

# do not merge the same sentence
results = ensemble_predictor.batch_predict(
lines, ignore_error=False, print_result=False
)
it = tqdm.tqdm(results, ncols=100)
for i, result in enumerate(it):
label = result["sentiment"]
if label == lines[i].split("$LABEL$")[-1].strip():
count1 += 1
batch_pred.append(label)
batch_gold.append(lines[i].split("$LABEL$")[-1].strip().split(","))
accuracy = count1 / (i + 1)
it.set_description(f"Accuracy: {accuracy:.4f}")

fprint(f"{model_name} {dataset} Accuracy: {accuracy:.4f}")

# while True:
# text = input("Please input your text sequence: ")
# if text == "exit":
# break
# if text == "":
# continue
# _, _, true_label = text.partition("$LABEL$")
# try:
# result = ensemble_predictor.predict(
# text, ignore_error=False, print_result=False
# )
# print(result)
# pred_label = result["label"]
# confidence = result["confidence"]
# fprint(
# "Predicted Label:",
# pred_label,
# "Reference Label: ",
# true_label,
# "Correct: ",
# pred_label == true_label,
# )
# except Exception as e:
# fprint(e)
23 changes: 8 additions & 15 deletions examples-v2/aspect_polarity_classification/train_apc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,13 @@
APC.APCModelList.BERT_SPC_V2,
]

datasets = DatasetItem(
[
APC.APCDatasetList.Laptop14,
APC.APCDatasetList.Restaurant14,
APC.APCDatasetList.Restaurant15,
APC.APCDatasetList.Restaurant16,
APC.APCDatasetList.MAMS,
]
)

for dataset in [
APC.APCDatasetList.Laptop14,
APC.APCDatasetList.Restaurant14,
APC.APCDatasetList.Restaurant15,
APC.APCDatasetList.Restaurant16,
# APCDatasetList.MAMS
APC.APCDatasetList.MAMS,
]:
for model in [
APC.APCModelList.FAST_LSA_T_V2,
Expand All @@ -47,6 +38,7 @@
]:
for pretrained_bert in [
"microsoft/deberta-v3-base",
# "bert-base-uncased",
# 'roberta-base',
# 'microsoft/deberta-v3-large',
]:
Expand All @@ -55,19 +47,19 @@
config.pretrained_bert = pretrained_bert
# config.pretrained_bert = 'roberta-base'
config.evaluate_begin = 0
config.max_seq_len = 80
config.max_seq_len = 512
config.num_epoch = 30
# config.log_step = 5
config.log_step = -1
config.patience = 1
config.patience = 999
config.dropout = 0.5
config.eta = -1
config.eta = 1
config.eta_lr = 0.001
# config.lcf = 'fusion'
config.cache_dataset = False
config.l2reg = 1e-8
config.learning_rate = 1e-5
config.use_amp = True
config.learning_rate = 2e-5
config.use_amp = False
config.use_bert_spc = True
config.lsa = True
config.use_torch_compile = False
Expand All @@ -79,6 +71,7 @@
# from_checkpoint='english',
checkpoint_save_mode=ModelSaveOption.SAVE_MODEL_STATE_DICT,
# checkpoint_save_mode=ModelSaveOption.DO_NOT_SAVE_MODEL,
path_to_save=f"checkpoints/{pretrained_bert}",
auto_device=DeviceTypeOption.AUTO,
)
trainer.load_trained_model()
10 changes: 6 additions & 4 deletions examples-v2/aspect_polarity_classification/train_apc_plm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,26 @@
# Copyright (C) 2021. All Rights Reserved.
import random

from pyabsa import AspectPolarityClassification, ModelSaveOption

########################################################################################################################
# train and evaluate on your own apc_datasets (need train and test apc_datasets) #
########################################################################################################################

from pyabsa import AspectPolarityClassification, ModelSaveOption

config = AspectPolarityClassification.APCConfigManager.get_apc_config_english()
config.evaluate_begin = 0
config.num_epoch = 1
config.max_seq_len = 80
config.max_seq_len = 160
config.log_step = -1
config.dropout = 0
config.l2reg = 1e-5
config.cache_dataset = False
config.seed = random.randint(0, 10000)
config.model = AspectPolarityClassification.BERTBaselineAPCModelList.ASGCN_BERT
# configuration_class.spacy_model = 'zh_core_web_sm'
# chinese_sets = ABSADatasetList.Chinese
chinese_sets = AspectPolarityClassification.APCDatasetList.Laptop14
chinese_sets = AspectPolarityClassification.APCDatasetList.ARTS_Laptop14
# chinese_sets = AspectPolarityClassification.APCDatasetList.ARTS_Restaurant14
# chinese_sets = ABSADatasetList.MOOC
sent_classifier = AspectPolarityClassification.APCTrainer(
config=config,
Expand Down
Loading

0 comments on commit 07a251f

Please sign in to comment.