-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_character_bin.py
executable file
·76 lines (55 loc) · 2.72 KB
/
run_character_bin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import sys
import os
import pickle
import numpy as np
import pandas as pd
from scipy import stats
import re
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import train_test_split
def _load_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only point this at files produced by a trusted pipeline.
    with open(path, 'rb') as file:
        return pickle.load(file)


def run_character_bin(repr_type, train_prefix, val_prefix, *,
                      output_dir='probing_results', n_jobs=80):
    """Fit a logistic-regression probe that predicts a binned pixel count.

    Reads ``<prefix>_all_embeddings.pkl`` and ``<prefix>_all_concept_pixels.pkl``
    for both the training and validation prefixes, bins the concept-pixel
    values, fits a class-balanced one-vs-rest logistic regression on the
    training embeddings, and writes the fitted model plus evaluation
    results into *output_dir*.

    Files written (``<r>`` is *repr_type*):
        * ``log_reg_char_bin_<r>.pkl``       -- pickled fitted classifier
        * ``<r>_char_bin.csv``               -- classification report, rounded to 2 decimals
        * ``<r>_char_bin_roc_auc.txt``       -- ROC AUC, one-vs-rest
        * ``<r>_char_bin_roc_auc_ovo.txt``   -- ROC AUC, one-vs-one

    Args:
        repr_type: Tag inserted into every output file name.
        train_prefix: Path prefix of the training pickle files.
        val_prefix: Path prefix of the validation pickle files.
        output_dir: Directory receiving all outputs; created if missing.
        n_jobs: Forwarded to ``LogisticRegression(n_jobs=...)``.

    Raises:
        FileNotFoundError: If any of the four input pickle files is missing.
    """
    # Create the output directory up front so a missing folder does not
    # crash the run only after the (expensive) model fit has finished.
    os.makedirs(output_dir, exist_ok=True)

    train_embeddings = _load_pickle(train_prefix + '_all_embeddings.pkl')
    train_concept_pixels = _load_pickle(train_prefix + '_all_concept_pixels.pkl')
    val_embeddings = _load_pickle(val_prefix + '_all_embeddings.pkl')
    val_concept_pixels = _load_pickle(val_prefix + '_all_concept_pixels.pkl')

    # The embeddings pickles are joined along the first axis.
    # assumes each pickle holds a sequence of per-batch arrays -- TODO confirm
    x_train = np.concatenate(train_embeddings)
    x_test = np.concatenate(val_embeddings)

    # Fixed bin edges for the concept-pixel values.
    # NOTE(review): 50176 == 224 * 224, so these look like pixel counts of a
    # 224x224 image -- confirm against the code that produced the pickles.
    bin_mapping = np.array([2194., 10191., 18188., 26185., 34182., 42179., 50176.])

    # np.digitize returns 0 for values below the first edge, so after the
    # shift by one those samples get label -1; values >= the last edge get 6.
    y_train = np.digitize(train_concept_pixels, bin_mapping) - 1
    y_test = np.digitize(val_concept_pixels, bin_mapping) - 1

    # NOTE(review): recent scikit-learn releases deprecate the ``multi_class``
    # argument; kept as-is to preserve the original one-vs-rest behavior.
    log_reg = LogisticRegression(class_weight='balanced', max_iter=500,
                                 multi_class='ovr', n_jobs=n_jobs)
    log_reg.fit(x_train, y_train)

    # Persist the fitted probe before evaluating it.
    pkl_filename = os.path.join(output_dir, 'log_reg_char_bin_' + repr_type + '.pkl')
    with open(pkl_filename, 'wb') as file:
        pickle.dump(log_reg, file)

    y_pred = log_reg.predict(x_test)
    y_pred_proba = log_reg.predict_proba(x_test)

    report = classification_report(y_test, y_pred, output_dict=True)
    results = pd.DataFrame.from_dict(report).round(2)
    results.to_csv(os.path.join(output_dir, repr_type + '_char_bin.csv'))

    roc_auc = roc_auc_score(y_test, y_pred_proba, multi_class='ovr')
    roc_auc_ovo = roc_auc_score(y_test, y_pred_proba, multi_class='ovo')

    # The ``with`` statement closes each file; no explicit close() is needed.
    with open(os.path.join(output_dir, repr_type + '_char_bin_roc_auc.txt'), 'w') as file:
        file.write(str(roc_auc))
    with open(os.path.join(output_dir, repr_type + '_char_bin_roc_auc_ovo.txt'), 'w') as file:
        file.write(str(roc_auc_ovo))
if __name__ == "__main__":
    # One probing run per representation, executed in this order:
    # (banner printed to stdout, repr_type tag, train prefix, val prefix).
    # NOTE(review): the SwAV prefixes lack the "superpixels_" part the other
    # three use -- confirm this matches the pickle files on disk.
    probing_runs = (
        ('SWAV', 'swav', 'train_swav', 'val_swav'),
        ('MOCO', 'moco', 'train_superpixels_moco', 'val_superpixels_moco'),
        ('BYOL', 'byol', 'train_superpixels_byol', 'val_superpixels_byol'),
        ('SIMCLR', 'simclr', 'train_superpixels_simclr', 'val_superpixels_simclr'),
    )
    for banner, tag, train_files, val_files in probing_runs:
        print(banner)
        run_character_bin(tag, train_files, val_files)