-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathclient_server_fake.py
134 lines (122 loc) · 4.84 KB
/
client_server_fake.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import numpy as np
import os
import pyhe_client
import subprocess
import time
from consts import out_client_name, out_final_name, inference_times_name
from mnist_util import client_argument_parser
from utils import client_data
from utils.main_utils import array_str
from utils.time_utils import log_timing
from utils.main_utils import round_array
def run_client(FLAGS, data, max_wait_s=100000, retry_delay_s=1.0):
    """Fake client for step 1 of the protocol: random logits instead of inference.

    Draws 10 random logits uniformly from [-1000, 1000), adds ``FLAGS.r_star``,
    rounds the sum with ``round_array``, and writes it (one integer per line)
    to ``{out_client_name}{port}privacy.txt``. It then runs the
    ``./gc-emp-test/bin/multi_label_softmax`` binary as party ``2`` until
    ``{out_final_name}{port}privacy.txt`` exists.

    Args:
        FLAGS: parsed command-line flags; reads ``port``, ``r_star``,
            ``round_exp`` and ``log_timing_file``.
        data: the flattened query. Ignored by this fake client; kept so the
            signature matches callers.
        max_wait_s: seconds to keep retrying the 2PC step before giving up.
        retry_delay_s: seconds to pause between consecutive 2PC attempts that
            did not produce the output file.

    Returns:
        ``(r_rstar, rstar)`` -- the rounded ``logits + r_star`` values and
        ``FLAGS.r_star`` itself.

    Raises:
        ValueError: if ``FLAGS.r_star`` is None, or if the output file does
            not appear within ``max_wait_s`` seconds.
    """
    port = FLAGS.port
    rstar = FLAGS.r_star
    # Validate before use; otherwise ``logits + None`` would raise a
    # TypeError before this check was ever reached.
    if rstar is None:
        raise ValueError('r_star should be provided but was None.')

    # Fake inference: random logits stand in for the model output.
    logits = np.random.uniform(-1000, 1000, 10)
    print(logits, rstar)
    r_rstar = round_array(x=logits + rstar, exp=FLAGS.round_exp)
    print('rounded r_rstar (r-r*): ', array_str(r_rstar))

    print("Writing out logits file to txt.")
    client_file = f'{out_client_name}{port}privacy.txt'
    final_file = f'{out_final_name}{port}privacy.txt'
    with open(client_file, 'w') as outfile:
        outfile.writelines(f"{int(val)}\n" for val in r_rstar.flatten())

    # do 2 party computation with each Answering Party
    msg = 'starting 2pc with Answering Party'
    print(msg)
    log_timing(stage='client:' + msg, log_file=FLAGS.log_timing_file)

    # NOTE(review): the 2PC port is hard-coded to 12345 and does not follow
    # FLAGS.port, unlike the file names above -- confirm this is intended.
    cmd = ['./gc-emp-test/bin/multi_label_softmax', '2', '12345', client_file]
    timeout_msg = "Step 1' of protocol never finished. Issue."
    deadline = time.time() + max_wait_s
    while not os.path.exists(final_file):
        remaining = deadline - time.time()
        if remaining <= 0:
            raise ValueError(timeout_msg)
        # Run each attempt to completion before retrying. Launching a new
        # process per loop iteration without waiting would pile up an
        # unbounded number of concurrent children.
        try:
            subprocess.run(cmd, timeout=remaining)
        except subprocess.TimeoutExpired as err:
            raise ValueError(timeout_msg) from err
        if not os.path.exists(final_file):
            time.sleep(retry_delay_s)

    log_timing(stage='client:finished 2PC', log_file=FLAGS.log_timing_file)
    return r_rstar, rstar
# print("Prepping for 2pc with CSP")
#
# r_rstars = []
# for port in flags.ports:
# with open(f'output{port}privacy.txt', 'r') as infile:
# r_rstar = []
# for line in infile:
# r_rstar.append(int(line))
# r_rstars.append(r_rstar)
# r_rstars = np.array(r_rstars, np.int64)
# print(r_rstars)
# print('done')
#
# if flags.final_call:
# fs = [f"output{port}privacy.txt" for port in flags.ports]
# array_sum = csp.sum_files(fs)
# print(array_sum)
# with open("output.txt", 'w') as outfile:
# for v in array_sum.flatten():
# outfile.write(f'{v}\n')
# csp_filenames = [f'noise{port}privacy.txt' for port in flags.ports]
# label = csp.get_histogram(
# client_filename='output.txt',
# csp_filenames=csp_filenames,
# csp_sum_filename='final.txt')
# print(label)
if __name__ == "__main__":
    # Script entry point: parse flags, load the query, run the fake client,
    # and report step-1 runtime.
    import sys

    FLAGS, unparsed = client_argument_parser().parse_known_args()
    if unparsed:
        # Reject unknown flags. sys.exit is used because the bare ``exit``
        # builtin is injected by ``site`` and may be missing (``python -S``,
        # frozen executables).
        print("Unparsed flags:", unparsed, file=sys.stderr)
        sys.exit(1)
    if not FLAGS.from_pytorch:
        raise ValueError('must be from pytorch')

    # NOTE(review): query/label/noisy below are read from the PyTorch dataset
    # but ``query`` is immediately replaced by the MNIST test split, and
    # ``label``/``noisy`` are unused -- confirm this is intended for the fake.
    queries, labels, noisies = client_data.load_data(FLAGS.dataset_path)
    query = queries[FLAGS.minibatch_id].transpose()
    label = labels[FLAGS.minibatch_id]
    noisy = noisies[FLAGS.minibatch_id]
    _, _, x_test, _ = client_data.load_mnist_data(0, 1)
    query = x_test

    # perf_counter is monotonic, so the measured runtime is immune to
    # wall-clock adjustments.
    start_time = time.perf_counter()
    print(query.shape)
    r_rstar, rstar = run_client(FLAGS=FLAGS, data=query[None, ...].flatten("C"))
    end_time = time.perf_counter()
    print(f'step 1 runtime: {end_time - start_time}s')
    log_timing('client_server:finish', log_file=FLAGS.log_timing_file)
# Check if stage 1 was executed correctly.
# if FLAGS.predict_labels_file is not None:
# port = FLAGS.ports[0]
# predict_labels_file = FLAGS.predict_labels_file + str(
# port) + '.npy'
# predict_labels = np.load(predict_labels_file)
# check_rstar_stage1(
# rstar=rstar,
# r_rstar=r_rstar,
# labels=predict_labels,
# port=port,
# )
# y_labels = labels.argmax(axis=1)
# print("y_test: ", y_labels)
#
# y_pred = y_pred_reshape.argmax(axis=1)
# print("y_pred: ", y_pred)
#
# correct = np.sum(np.equal(y_pred, y_labels))
# acc = correct / float(flags.batch_size)
# print("correct from original result: ", correct)
# print(
# "Accuracy original result (batch size", flags.batch_size, ") =",
# acc * 100.0, "%")