-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathphacelia.py
349 lines (282 loc) · 12.5 KB
/
phacelia.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# Author: Mahder Teka
# PHACELIA
# Last Modified: Feb, 2017
#This program takes a fasta file as input and generates 600 features based on 10 Physical Chemical Properties
#Selects 54 features to be used in analysis and uses a logistic regression classification model to identify recently infected HCV patients
#Scoring function used to make predictions was developed by Zoya Dimitrova
import argparse
from argparse import FileType, ArgumentDefaultsHelpFormatter
from itertools import tee, islice, chain
from math import frexp
from collections import defaultdict
from pkg_resources import resource_filename, resource_stream
import io
import sys
import csv
import re
import numpy as np
from toolz.itertoolz import first#,second,tail
from sklearn.externals import joblib
from phacelia.acc import read_and_filter, binned_auto_covariance,get_cutoff_vect
from phacelia.util.phyche_file import parse_nucl_index, smart_open
class WeightedAverage(object):
    """\
    Running weighted average that accumulates its numerator and denominator
    exactly, by tracking each sum as an arbitrary-precision integer mantissa
    paired with a binary exponent (no float round-off during accumulation).
    From code.activestate.com/recipes/393090
    """
    def __init__(self):
        # Numerator: running sum of value * weight, as (mantissa, exponent).
        self.n_tmant = 0
        self.n_texp = 0
        # Denominator: running sum of weights, as (mantissa, exponent).
        self.d_tmant = 0
        self.d_texp = 0

    @staticmethod
    def _add(tmant, texp, val):
        """Return (tmant, texp) with float ``val`` folded in exactly."""
        mant, exp = frexp(val)
        # Scale the 53-bit float mantissa up to an exact integer.
        mant = int(mant * 2.0 ** 53)
        exp -= 53
        # Align both operands to the smaller exponent before adding.
        if exp < texp:
            tmant <<= texp - exp
            texp = exp
        else:
            mant <<= exp - texp
        return tmant + mant, texp

    @staticmethod
    def _val(tmant, texp):
        """Convert a (mantissa, exponent) pair back to a float."""
        # float(str(...)) parses the decimal form, which yields inf for
        # huge mantissas instead of raising OverflowError like float(int).
        return float(str(tmant)) * 2.0 ** texp

    @property
    def avg(self):
        """Current weighted average (numerator / denominator)."""
        numerator = self._val(self.n_tmant, self.n_texp)
        denominator = self._val(self.d_tmant, self.d_texp)
        return numerator / denominator

    def accumulate(self, val, weight):
        """Fold one (value, weight) observation into the running sums."""
        self.n_tmant, self.n_texp = self._add(self.n_tmant, self.n_texp, val * weight)
        self.d_tmant, self.d_texp = self._add(self.d_tmant, self.d_texp, weight)
def binned_classification(seqs, k, phyche_vals, cutoff_vect, classifier,
                          chunk_size=10, new_ac=False):
    """
    Annotate each sequence with the classifier's class probabilities and
    yield it.

    Parameters
    ----------
    seqs : iterable of SeqRecord
        Sequences to classify; falsy entries are skipped.
    k : int
        k value forwarded to binned_auto_covariance.
    phyche_vals :
        Physico-chemical index values, as parsed by parse_nucl_index.
    cutoff_vect :
        Feature-selection cutoffs, as produced by get_cutoff_vect.
    classifier :
        Fitted sklearn-style classifier providing predict_proba() and
        classes_.
    chunk_size : int, optional
        Number of feature vectors scored per predict_proba() call.
    new_ac : bool, optional
        Forwarded to binned_auto_covariance.

    Yields
    ------
    SeqRecord
        The input sequence, with one annotation per classifier class
        holding that class's predicted probability.
    """
    seqs = (s for s in seqs if s)
    # One copy feeds the feature generator; the other is re-paired with the
    # probabilities so annotations land on the matching record.
    seqs1, seqs2 = tee(seqs)
    # BUG FIX: forward the k parameter; it was previously ignored and a
    # hard-coded 2 was passed instead (all current callers pass k=2, so
    # existing behavior is unchanged).
    features = binned_auto_covariance(seqs1, phyche_vals, k, cutoff_vect,
                                      new_ac=new_ac)
    # can't use partition_all here, xarray does not like that. So we need this
    # loop over islice chunks rather than the cleaner 'for c in partition_all('
    while True:
        feature_vect = list(islice(features, chunk_size))
        if not feature_vect:
            break
        probs = classifier.predict_proba(np.vstack(feature_vect))
        for i in range(probs.shape[0]):
            s = next(seqs2)
            for j, c in enumerate(classifier.classes_):
                # classes_ entries are decoded bytes -> str for use as
                # annotation keys (assumes the pickled model stores bytes
                # labels — TODO confirm).
                s.annotations[c.decode('utf-8')] = probs[i][j]
            yield s
        # A short chunk means the feature stream is exhausted.
        if probs.shape[0] < chunk_size:
            break
# Thresholds on the weighted-average P(Chronic) bounding the indeterminate
# "gray zone": averages below BEGIN are called Recent, averages above END
# are called non-Recent, and anything in between is Indeterminate.
RECENCY_GRAYZONE_BEGIN = 0.416912205235602
RECENCY_GRAYZONE_END = 0.487643239921338
def recency_bin(samples):
    """
    Classify each (sample, genotype) pair seen across *samples* by its
    frequency-weighted average P(Chronic).

    Parameters
    ----------
    samples : iterable of iterables of SeqRecord
        Each element is one sample's sequences, already annotated with
        'sample', 'genotype', and 'freq' keys.

    Yields
    ------
    tuple
        (sample_id, genotype, avg_probability, classification) where
        classification is 0 (below the gray zone), 1 (inside it), or
        2 (above it).
    """
    #Runs acc script to create file with 600 features
    # "python acc.py %s output.csv -lag 60 DNA DAC -e dnaDAC_10.txt -f csv" % filename
    with resource_stream('phacelia', 'data/dnaDAC_10.txt') as in_fd:
        phyche_vals = parse_nucl_index(in_fd)
    # NOTE(review): main() wraps this stream in io.TextIOWrapper before
    # calling get_cutoff_vect, but here the raw binary stream is passed —
    # confirm get_cutoff_vect accepts both.
    with resource_stream('phacelia', 'data/cutoff.csv') as c_fd:
        cutoff_vect = get_cutoff_vect(c_fd)
    clf = joblib.load(resource_filename('phacelia', 'data/PHACELIA.pkl'))
    # One exact weighted average of P(Chronic) per (sample, genotype) pair.
    avgs = defaultdict(WeightedAverage)
    for sample in samples:
        seqs = binned_classification(
            sample,
            2,
            phyche_vals,
            cutoff_vect,
            clf,
            new_ac=False)
        for s in seqs:
            k = (s.annotations['sample'],s.annotations['genotype'])
            # NOTE(review): annotate_sequence() stores the count under
            # 'frequency', not 'freq' — confirm callers of recency_bin
            # annotate with 'freq', otherwise this raises KeyError.
            avgs[k].accumulate(s.annotations['Chronic'], s.annotations['freq'])
    for (pat,gen),v in avgs.items():
        # 0 = Recent, 1 = Indeterminate (gray zone), 2 = non-Recent.
        classification = 0
        if v.avg > RECENCY_GRAYZONE_BEGIN:
            classification = 1
        if v.avg > RECENCY_GRAYZONE_END:
            classification = 2
        yield pat,gen,v.avg,classification
def annotate_sequence(seq, field_dict=None, name_regex=None):
    """
    Parse a sequence's description line and record the sample ID, genotype,
    and frequency in its annotations dict.

    Parameters
    ----------
    seq : SeqRecord object
        The sequence to annotate, as generated by Bio.SeqIO.parse().
    field_dict : dictionary, optional
        Maps 'sample', 'genotype', and 'frequency' to 0-indexed positions
        in the separator-split description (negative positions count from
        the end); 'sep' holds the field separator. Used only when
        name_regex is not given.
    name_regex : compiled regex, optional
        Pattern with named groups i (sample ID), t (genotype), and
        f (frequency). Takes precedence over field_dict. On a failed
        match, an error is printed to stderr and placeholder annotations
        are stored.

    Returns
    -------
    seq : SeqRecord object
        The annotated SeqRecord object.
    """
    annotations = seq.annotations
    if name_regex:
        match = name_regex.match(seq.description)
        if not match:
            # Unparseable defline: warn and fall back to placeholders.
            print('Error splitting name: ', seq.description, file=sys.stderr)
            annotations['sample'] = 'undetermined'
            annotations['genotype'] = 'undetermined'
            annotations['frequency'] = 0
            return seq
        groups = match.groupdict()
        annotations['frequency'] = int(groups['f'])
        annotations['genotype'] = groups['t']
        annotations['sample'] = groups['i']
        return seq
    parts = seq.description.split(field_dict['sep'])
    annotations['frequency'] = int(parts[field_dict['frequency']])
    annotations['genotype'] = parts[field_dict['genotype']]
    annotations['sample'] = parts[field_dict['sample']]
    return seq
def zero_shift(field_dict):
    """
    Normalize a user-supplied field mapping in place and return it.

    Fills in the default separator and default 1-indexed positions for
    'sample', 'genotype', and 'frequency', then converts positive
    (1-indexed) positions to 0-indexed ones; negative positions already
    count from the end and are left untouched.
    """
    field_dict.setdefault('sep', '_')
    for key, default in (('sample', 1), ('genotype', -2), ('frequency', -1)):
        position = field_dict.get(key, default)
        field_dict[key] = position - 1 if position > 0 else position
    return field_dict
def main(args):
    """
    Entry point for the 'classify' subcommand: annotate and classify
    sequences (or whole samples) as Recent / Indeterminate / non-Recent
    and write a CSV of (identifier, probability, prediction).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed 'classify' arguments (see subcommand()).
    """
    # Load the packaged physico-chemical index, feature cutoffs, and
    # pickled classifier.
    with resource_stream('phacelia', 'data/dnaDAC_10.txt') as in_fd:
        phyche_vals = parse_nucl_index(in_fd)
    with resource_stream('phacelia', 'data/cutoff.csv') as c_fd:
        cutoff_vect = get_cutoff_vect(io.TextIOWrapper(c_fd, encoding='utf-8'))
    clf = joblib.load(resource_filename('phacelia', 'data/PHACELIA.pkl'))
    # Read in all fasta files, filtering out sequences with N's
    seqs = chain.from_iterable(read_and_filter(fn) for fn in args.input_files)
    field_dict = zero_shift(args.field_dict)
    # Parse the sequence headers and use them to annotate the sequences
    if args.name_regex:
        seqs = map(lambda s: annotate_sequence(s, name_regex=args.name_regex), seqs)
    else:
        seqs = map(lambda s: annotate_sequence(s, field_dict=field_dict), seqs)
    # Remove sequences whose frequency is below the cutoff
    seqs = filter(lambda s: s.annotations['frequency'] >= args.freq_cutoff, seqs)
    # Predict probability of 'chronic' for each sequence.
    # BUG FIX: honor the --new_ac command-line flag; new_ac was previously
    # hard-coded to False even though the parser defines the option.
    seqs = binned_classification(seqs, 2, phyche_vals, cutoff_vect, clf,
                                 new_ac=args.new_ac)
    # Use probabilities to classify sample or sequence as 'Chronic' or 'Recent':
    # per sample+genotype (frequency-weighted) or per individual sequence.
    probs = defaultdict(WeightedAverage)
    if args.unit == 'sample':
        for s in seqs:
            key = '_'.join([s.annotations['sample'], s.annotations['genotype']])
            probs[key].accumulate(s.annotations['Chronic'], s.annotations['frequency'])
    elif args.unit == 'sequence':
        for s in seqs:
            key = s.description
            probs[key].accumulate(s.annotations['Chronic'], 1)
    # Write results to a file (or stdout when no output file was given).
    with smart_open(args.output_file) as f_out:
        print('identifier', 'probability', 'prediction', sep=',', file=f_out)
        for ID, pr in probs.items():
            # Below the gray zone -> Recent; inside -> Indeterminate;
            # above -> non-Recent.
            if pr.avg < RECENCY_GRAYZONE_BEGIN:
                prediction = 'Recent'
            elif pr.avg < RECENCY_GRAYZONE_END:
                prediction = 'Indeterminate'
            else:
                prediction = 'non-Recent'
            print(ID, "{value:.{digits}f}".format(value=pr.avg, digits=args.round),
                  prediction, sep=',', file=f_out)
def main_features(args):
    """
    Entry point for the 'features' subcommand: write the auto-covariance
    feature matrix for all input sequences as CSV (one row per sequence,
    one column per selected index/lag pair).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed 'features' arguments (see subcommand()).
    """
    with resource_stream('phacelia', 'data/dnaDAC_10.txt') as in_fd:
        phyche_vals = parse_nucl_index(in_fd)
    # NOTE(review): main() wraps this stream in io.TextIOWrapper; confirm
    # get_cutoff_vect also accepts the raw binary stream used here.
    with resource_stream('phacelia', 'data/cutoff.csv') as c_fd:
        cutoff_vect = get_cutoff_vect(c_fd)
    seqs = list(chain(*[read_and_filter(fn) for fn in args.input_files]))
    # BUG FIX: honor the --new_ac command-line flag; new_ac was previously
    # hard-coded to False even though the parser defines the option.
    features = binned_auto_covariance(seqs, phyche_vals, 2, cutoff_vect,
                                      new_ac=args.new_ac)
    with smart_open(args.output_file, newline='') as f_out:
        writer = csv.writer(f_out)
        # Column names are '<index> L<lag>' with whitespace collapsed to '_'.
        header = ['seq_ID'] + [' L'.join(map(str, [r['index'], r['lag']]))
                               for i, r in cutoff_vect.iterrows()]
        header = ['_'.join(s.split()) for s in header]
        writer.writerow(header)
        for s, f in zip(seqs, features):
            writer.writerow([s.description] + list(f.data))
class StoreDictKeyPair(argparse.Action):
    """
    A custom argparse.Action class to parse KEY=VAL pairs and store them as
    a dictionary on the namespace. Values for the 'sep' key are kept as
    strings; all other values must parse as integers. Thanks to
    @storm_m2138 on stackoverflow:
    https://stackoverflow.com/questions/29986185/python-argparse-dict-arg/42355279
    """
    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        self._nargs = nargs
        super(StoreDictKeyPair, self).__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        my_dict = {}
        for kv in values:
            k, v = kv.split('=')
            if k == 'sep':
                # The separator stays a string; everything else is a field index.
                my_dict[k] = v
            else:
                try:
                    my_dict[k] = int(v)
                except ValueError:
                    raise ValueError('Invalid field value. Field values must be integers.')
        # BUG FIX: setattr used to live in the try/else inside the loop, so
        # a values list consisting only of 'sep=...' pairs never stored the
        # dict on the namespace. Store it once, after all pairs are parsed.
        setattr(namespace, self.dest, my_dict)
def subcommand(subparsers):
    """Register the 'classify' and 'features' subcommands on *subparsers*."""
    # --- classify: run the model and report predictions ---------------------
    classify_p = subparsers.add_parser(
        'classify',
        formatter_class=ArgumentDefaultsHelpFormatter,
        help='Classify haplotypes as recent or non-recent variants.')
    classify_p.set_defaults(func=main)
    classify_p.add_argument(
        'input_files', type=FileType('r'), nargs='+',
        help='The fasta files of sequences to classify.')
    classify_p.add_argument(
        '-u', '--unit', choices=['sequence', 'sample'], default='sequence',
        help='Unit of classification.')
    classify_p.add_argument(
        '-n', '--new_ac', default=False, dest='new_ac', action='store_true',
        help='Use the new AC calculation.')
    classify_p.add_argument(
        '-f', '--freq-cutoff', default=10, type=int,
        help='Ignore sequences with frequency less than this.')
    classify_p.add_argument(
        '--fields', action=StoreDictKeyPair, dest='field_dict', nargs='+',
        metavar='KEY=VAL',
        default={'sample': 1, 'genotype': -2, 'frequency': -1, 'sep': '_'},
        help='Key-value pairs indicating the field numbers in the fasta '
             'defline to use for sample ID, genotype, and sequence '
             'frequency. Use 1 for the first field, etc., and use negative '
             'numbers to count from the end of the defline. Example: sample=1')
    classify_p.add_argument(
        '-s', '--name-regex', default=None, type=re.compile, dest='name_regex',
        help='The regex used to split the name into sections. Should have '
             'named groups:\n'
             "i: identifer, t: type, f: frequency. This option is an "
             "alternative to the '--fields' argument. If this is used, the "
             "'--fields' options will be ignored.")
    classify_p.add_argument(
        '-r', '--round', type=int, default=6,
        help='The number of decimal places to round predicted '
             'probabilities to.')
    classify_p.add_argument(
        '-o', '--output_file',
        help='Optional output file name. If none is specified,'
             'classification results will be printed to stdout.')
    # --- features: dump the feature matrix without classifying --------------
    features_p = subparsers.add_parser(
        'features',
        formatter_class=ArgumentDefaultsHelpFormatter,
        help='Output the features that would go into a model.')
    features_p.set_defaults(func=main_features)
    features_p.add_argument(
        'input_files', type=FileType('r'), nargs='+',
        help='The fasta files of sequences to construct features from.')
    features_p.add_argument(
        '-n', '--new_ac', default=False, dest='new_ac', action='store_true',
        help='Use the new AC calculation.')
    features_p.add_argument(
        '-o', '--output_file',
        help='Optional output file name. If none is specified,'
             'classification results will be printed to stdout.')