-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathlaunch_test.py
103 lines (84 loc) · 4.49 KB
/
launch_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/python
import pickle, getopt, sys, time, re
import datetime, os;
import scipy.io;
import nltk;
import numpy;
import optparse;
def parse_args():
    """Parse the command-line options of the held-out evaluation script.

    Returns:
        optparse.Values with attributes ``input_directory`` (str or None),
        ``model_directory`` (str or None) and ``snapshot_index`` (int; -1 when
        not given, meaning "evaluate every available snapshot").
    """
    option_parser = optparse.OptionParser()
    # parameter set 1
    option_parser.set_defaults(input_directory=None,
                               model_directory=None,
                               snapshot_index=-1)
    # (flag, optparse type, destination attribute, help text) -- order matters for --help.
    option_specs = (
        ("--input_directory", "string", "input_directory", "input directory [None]"),
        ("--model_directory", "string", "model_directory", "model directory [None]"),
        ("--snapshot_index", "int", "snapshot_index",
         "snapshot index [-: evaluate on all available snapshots]"),
    )
    for flag, value_type, destination, help_text in option_specs:
        option_parser.add_option(flag, type=value_type, dest=destination, help=help_text)
    parsed_options, _positional = option_parser.parse_args()
    return parsed_options
def _snapshot_output_paths(model_directory, snapshot_index):
    """Return the (lambda, nu_square) output file paths for one snapshot index."""
    output_lambda_path = os.path.join(model_directory, "test-lambda-%d" % snapshot_index)
    output_nu_square_path = os.path.join(model_directory, "test-nu_square-%d" % snapshot_index)
    return output_lambda_path, output_nu_square_path


def _load_test_docs(test_docs_path):
    """Return every line of test_docs_path, stripped and lower-cased (blank lines kept as '')."""
    # `with` guarantees the handle is closed once the documents are read.
    with open(test_docs_path, 'r') as input_doc_stream:
        return [line.strip().lower() for line in input_doc_stream]


def _list_snapshots(model_directory):
    """Return (index, file name) pairs for 'model-<int>' entries of model_directory, sorted by index."""
    snapshots = []
    for model_snapshot in os.listdir(model_directory):
        if not model_snapshot.startswith("model-"):
            continue
        try:
            snapshot_index = int(model_snapshot.split("-")[-1])
        except ValueError:
            # e.g. "model-3.bak": report and skip instead of aborting the whole run.
            sys.stderr.write("warning: skipping %s, no integer snapshot index...\n" % model_snapshot)
            continue
        snapshots.append((snapshot_index, model_snapshot))
    # os.listdir returns entries in arbitrary order; evaluate snapshots by index.
    snapshots.sort()
    return snapshots


def main():
    """Evaluate held-out likelihood of one or all model snapshots on a test set.

    Options come from parse_args(). The basename of --input_directory must equal
    the name of the directory that contains --model_directory. Test documents are
    read from <input_directory>/test.dat, one document per line.

    If --snapshot_index is >= 0, only <model_directory>/model-<index> is evaluated;
    otherwise every 'model-<int>' entry of model_directory is evaluated in
    increasing index order. For snapshot <index>, evaluate_snapshot writes its
    results to test-lambda-<index> and test-nu_square-<index> in model_directory.

    Problems (missing options, missing model directory or snapshot, corpus name
    mismatch) are reported on stderr and main returns early.
    """
    options = parse_args()

    # parameter set 1
    # Explicit checks instead of assert: asserts are stripped under `python -O`.
    if options.input_directory is None:
        sys.stderr.write("error: --input_directory is required...\n")
        return
    if options.model_directory is None:
        sys.stderr.write("error: --model_directory is required...\n")
        return

    input_directory = options.input_directory.rstrip("/")
    input_corpus_name = os.path.basename(input_directory)

    model_directory = options.model_directory.rstrip("/")
    if not os.path.exists(model_directory):
        sys.stderr.write("error: model directory %s does not exist...\n" % (os.path.abspath(model_directory)))
        return

    # The directory holding model_directory must be named after the input corpus.
    corpus_directory = os.path.split(os.path.abspath(model_directory))[0]
    model_corpus_name = os.path.split(os.path.abspath(corpus_directory))[1]
    if input_corpus_name != model_corpus_name:
        sys.stderr.write("error: corpus name does not match for input (%s) and model (%s)...\n" % (input_corpus_name, model_corpus_name))
        return

    snapshot_index = options.snapshot_index

    print("========== ========== ========== ========== ==========")
    # parameter set 1
    print("model_directory=" + model_directory)
    print("input_directory=" + input_directory)
    print("corpus_name=" + input_corpus_name)
    print("snapshot_index=" + str(snapshot_index))
    print("========== ========== ========== ========== ==========")

    # Document
    test_docs_path = os.path.join(input_directory, 'test.dat')
    test_docs = _load_test_docs(test_docs_path)
    print("successfully load all testing docs from %s..." % (os.path.abspath(test_docs_path)))

    # Collect (index, snapshot path) pairs to evaluate.
    if snapshot_index >= 0:
        input_snapshot_path = os.path.join(model_directory, ("model-%d" % (snapshot_index)))
        if not os.path.exists(input_snapshot_path):
            sys.stderr.write("error: model snapshot %s does not exist...\n" % (os.path.abspath(input_snapshot_path)))
            return
        snapshots = [(snapshot_index, input_snapshot_path)]
    else:
        snapshots = [(index, os.path.join(model_directory, name))
                     for index, name in _list_snapshots(model_directory)]

    for index, input_snapshot_path in snapshots:
        output_lambda_path, output_nu_square_path = _snapshot_output_paths(model_directory, index)
        evaluate_snapshot(input_snapshot_path, test_docs, output_lambda_path, output_nu_square_path)
def evaluate_snapshot(input_snapshot_path, test_docs, output_lambda_path, output_nu_square_path):
    """Run held-out inference with one pickled model snapshot and save its outputs.

    Args:
        input_snapshot_path: path to a pickled object exposing
            ``inference(test_docs)``, which must return a 3-tuple
            ``(log_likelihood, lambda_values, nu_square_values)``.
        test_docs: test documents, passed unchanged to ``inference``.
        output_lambda_path: file that receives ``lambda_values`` via ``numpy.savetxt``.
        output_nu_square_path: file that receives ``nu_square_values`` via ``numpy.savetxt``.

    Prints the held-out likelihood returned by ``inference``.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code -- only load snapshots
    # from trusted sources. Unpickling also requires the module that defines the
    # inferencer's class to be importable in this process.
    with open(input_snapshot_path, "rb") as snapshot_stream:
        lda_inferencer = pickle.load(snapshot_stream)
    log_likelihood, lambda_values, nu_square_values = lda_inferencer.inference(test_docs)
    print("held-out likelihood of snapshot %s is %g" % (os.path.abspath(input_snapshot_path), log_likelihood))
    numpy.savetxt(output_lambda_path, lambda_values)
    numpy.savetxt(output_nu_square_path, nu_square_values)
if __name__ == '__main__':
    # Script entry point: running the file evaluates snapshots; importing it does not.
    main()