forked from UKPLab/emnlp2018-april
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstage0_sample_summaries.py
84 lines (57 loc) · 2.24 KB
/
stage0_sample_summaries.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import sys

from summariser.vector.vector_generator import Vectoriser
from summariser.utils.corpus_reader import CorpusReader
from resources import *
from summariser.utils.writer import append_to_file
def writeSample(actions, reward, path, writer=None):
    """Append one sampled summary and its reward(s) to a sample file.

    The record layout is selected by the final component (file name) of
    ``path``:

    * name contains ``'heuristic'``: a single record
      ``'\\nactions:<a0>,<a1>,...\\nutility:<repr(reward)>'``.
    * name contains ``'rouge'``: a leading ``'\\n'`` followed by one record
      per model name in ``reward``:
      ``'\\nmodel <j>:<name>\\nactions:<a0>,...'
      '\\nR1:..;R2:..;R3:..;R4:..;RL:..;RSU:..'``.

    Args:
        actions: Iterable of actions forming the sample; each is written
            with ``repr`` and the results are comma-separated.
        reward: For heuristic files, a single utility value. For rouge
            files, a mapping iterated for model names and indexed by them;
            entries 0..5 of each value are written as R1, R2, R3, R4, RL
            and RSU.
        path: Destination file path; only its file name selects the layout.
        writer: Callable ``writer(text, path)`` that emits the record.
            Defaults to ``append_to_file``.

    Raises:
        ValueError: If the file name contains neither 'heuristic' nor 'rouge'.
    """
    if writer is None:
        writer = append_to_file

    # Render the actions exactly once. join() avoids the old trailing-comma
    # trim, which ate the ':' of 'actions:' when `actions` was empty, and a
    # single pass means a one-shot iterator is not exhausted after the
    # first model's record in the rouge layout.
    actions_text = ','.join(repr(act) for act in actions)

    # Match on the file name only, so directory names such as
    # '.../heuristic_runs/rouge' cannot select the wrong layout.
    file_name = os.path.basename(path)

    if 'heuristic' in file_name:
        text = '\nactions:' + actions_text + '\nutility:' + repr(reward)
    elif 'rouge' in file_name:
        parts = ['\n']
        for j, model_name in enumerate(reward):
            scores = reward[model_name]
            parts.append('\nmodel {}:{}'.format(j, model_name))
            parts.append('\nactions:' + actions_text)
            parts.append('\nR1:{};R2:{};R3:{};R4:{};RL:{};RSU:{}'.format(
                *(scores[k] for k in range(6))))
        text = ''.join(parts)
    else:
        raise ValueError(
            "writeSample: cannot tell sample type from path {!r}; expected "
            "'heuristic' or 'rouge' in the file name".format(path))

    writer(text, path)
if __name__ == '__main__':
    # For each selected topic of `dataset`, sample `summary_num` summaries via
    # Vectoriser.sampleRandomReviews and append them, with their heuristic and
    # ROUGE rewards, to <SUMMARY_DB_DIR>/<dataset>/<topic>/{heuristic,rouge}.
    # Usage: stage0_sample_summaries.py [DATASET START END]
    # Topics are numbered from 1; those with START < number <= END are sampled.
    if len(sys.argv) == 4:
        dataset = sys.argv[1]
        start = int(sys.argv[2])
        end = int(sys.argv[3])
    elif len(sys.argv) == 1:
        dataset = 'DUC2001'  # DUC2001, DUC2002, DUC2004
        start = 0
        end = 9999
    else:
        # Previously any other argument count silently fell back to defaults.
        sys.exit('usage: {} [DATASET START END]'.format(sys.argv[0]))

    summary_len = 100
    summary_num = 10001
    base_dir = os.path.join(SUMMARY_DB_DIR, dataset)
    reader = CorpusReader(PROCESSED_PATH)
    data = reader.get_data(dataset, summary_len)

    for topic_cnt, (topic, docs, models) in enumerate(data, start=1):
        if not start < topic_cnt <= end:
            continue
        dir_path = os.path.join(base_dir, topic)
        os.makedirs(dir_path, exist_ok=True)
        vec = Vectoriser(docs, summary_len)
        print('-----Generate samples for topic {}: {}-----'.format(topic_cnt, topic))
        # NOTE(review): the two positional True flags are passed through
        # unchanged; their meaning is defined by Vectoriser.sampleRandomReviews.
        act_list, h_rewards, r_rewards = vec.sampleRandomReviews(summary_num, True, True, models)
        assert len(act_list) == len(h_rewards) == len(r_rewards)
        heuristic_path = os.path.join(dir_path, 'heuristic')
        rouge_path = os.path.join(dir_path, 'rouge')
        for actions, h_reward, r_reward in zip(act_list, h_rewards, r_rewards):
            writeSample(actions, h_reward, heuristic_path)
            writeSample(actions, r_reward, rouge_path)