"""train_svm.py

Trains a linear classifier (scikit-learn SGDClassifier, configured under the
"svm" key of config/train_params.yaml) on MFCC features extracted from the
SPCUP22 dataset, checkpointing the model with pickle whenever the validation
error improves and once more after the final epoch.
"""
from argparse import ArgumentParser
from datetime import datetime
import os
import pathlib
import pickle

import numpy as np
from sklearn.linear_model import SGDClassifier
from torchvision.transforms import Compose
from tqdm import tqdm

from datasets.SPCUP22DataModule import SPCUP22DataModule
from features.audio import MFCC
from utils.config import load_config_file


def build_parser() -> ArgumentParser:
    """Builds the command-line argument parser for the training script."""
    parser = ArgumentParser()
    parser.add_argument(
        "--checkpoint-save-path", type=str, default="./checkpoints/svm"
    )
    parser.add_argument(
        "--include-unseen-data-in-training", action="store_true", default=False
    )
    parser.add_argument(
        "--include-augmented-data", action="store_true", default=False
    )
    return parser
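

# Example invocation (a sketch; paths are illustrative, flag names are the ones
# defined in build_parser above):
#
#   python train_svm.py \
#       --checkpoint-save-path ./checkpoints/svm \
#       --include-augmented-data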


def save_checkpoint(classifier, filepath, model_params=None):
    """Pickles the classifier to `filepath`.

    If given, the model hyperparameters are attached to the classifier object
    before serialization so they travel with the checkpoint.
    """
    if model_params:
        classifier.model_params = model_params

    with open(filepath, "wb") as model_file_obj:
        pickle.dump(classifier, model_file_obj)
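

# Loading a checkpoint later is the inverse of save_checkpoint; a minimal
# sketch (the filename below is illustrative, following the pattern used in
# the training loop):
#
#   with open("./checkpoints/svm/svm-<timestamp>-<error>.pkl", "rb") as f:
#       classifier = pickle.load(f)
#   print(classifier.model_params)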


if __name__ == "__main__":
    current_timestamp = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")

    parser = build_parser()
    args = parser.parse_args()

    os.makedirs(args.checkpoint_save_path, exist_ok=True)

    # config and params
    train_config = load_config_file("./config/train_params.yaml")["svm"]
    batch_size = train_config["training"]["batch_size"]
    epochs = train_config["training"]["epochs"]
    n_mfcc = train_config["features"]["n_mfcc"]
    hop_length = train_config["features"]["hop_length"]

    # feature: MFCC extractor wrapped in a torchvision Compose so it can be
    # passed to the datamodule as its transform
    mfcc_extractor = MFCC(n_mfcc=n_mfcc, hop_length=hop_length)
    transforms = Compose([mfcc_extractor])
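
    # Each clip's MFCC matrix is later flattened into a single feature vector
    # (see the np.reshape calls below) before being fed to the linear classifier.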

    # datamodule
    data_module = SPCUP22DataModule(
        batch_size=batch_size,
        dataset_root=str(pathlib.Path("./data/raw_audio/spcup22").absolute()),
        transform=transforms,
        should_include_unseen_in_training_data=args.include_unseen_data_in_training,
        should_include_augmented_data=args.include_augmented_data,
    )
    data_module.prepare_data()
    data_module.setup()
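
    # prepare_data/setup mirror the PyTorch Lightning DataModule flow
    # (presumably: fetch/unpack the data, then build the train/val splits).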

    # full set of class labels, needed by partial_fit since an individual
    # mini-batch may not contain every class
    classes = np.arange(data_module.num_classes)

    # svm: SGDClassifier is a linear model fit with stochastic gradient descent;
    # its hyperparameters come from the "params" section of the config
    classifier = SGDClassifier(**train_config["params"])

    # best (lowest) validation error seen so far, used for checkpointing
    validation_error = float("inf")

    # fit
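    # Training streams mini-batches from the dataloaders and updates the
    # classifier incrementally with partial_fit, so the full feature matrix
    # never has to be materialized in memory at once. scikit-learn requires the
    # complete label set on the first partial_fit call, hence `classes` above.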
    for epoch in tqdm(range(epochs)):
        train_data = data_module.train_dataloader()
        val_data = data_module.val_dataloader()

        for batch in train_data:
            samples, labels, _ = batch
            samples = samples.detach().numpy()
            labels = labels.detach().numpy()

            # flatten each MFCC matrix into one feature vector; use the actual
            # batch size so a smaller final batch does not raise an error
            samples = np.reshape(samples, newshape=(samples.shape[0], -1))

            classifier.partial_fit(samples, labels, classes=classes)
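
        # Validation: average the per-batch error (1 - accuracy) over all
        # validation batches to get this epoch's validation error.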
        current_val_error = 0
        num_val_batches = 0

        for batch in val_data:
            samples, labels, _ = batch
            samples = samples.detach().numpy()
            labels = labels.detach().numpy()
            samples = np.reshape(samples, newshape=(samples.shape[0], -1))

            accuracy = classifier.score(samples, labels)
            error = 1 - accuracy

            current_val_error += error
            num_val_batches += 1

        current_val_error /= num_val_batches
        print("Validation Error: {:.2f}".format(current_val_error))

        # only save a new checkpoint when the validation error improves
        if current_val_error < validation_error:
            validation_error = current_val_error

            model_filename = "svm-{}-{:.2f}.pkl".format(
                current_timestamp, current_val_error
            )
            model_path = str(
                pathlib.Path(args.checkpoint_save_path).joinpath(
                    model_filename
                )
            )
            save_checkpoint(classifier, model_path, train_config["params"])

    # save the last checkpoint after all epochs are completed
    model_filename = "svm-{}-{:.2f}.pkl".format("last", current_val_error)
    model_path = str(
        pathlib.Path(args.checkpoint_save_path).joinpath(model_filename)
    )
    save_checkpoint(classifier, model_path, train_config["params"])