From e3af2251e6eabbd0a7df2891b00d714b00981117 Mon Sep 17 00:00:00 2001 From: sergpolly Date: Fri, 5 Jan 2024 10:52:44 -0500 Subject: [PATCH] quick fix bugs in stats --- pairtools/lib/stats.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/pairtools/lib/stats.py b/pairtools/lib/stats.py index 78967be..1495eaa 100644 --- a/pairtools/lib/stats.py +++ b/pairtools/lib/stats.py @@ -176,17 +176,13 @@ def __getitem__(self, key, filter="no_filter"): # there is only genomic distance range of the bin that's left: (bin_range,) = k_fields # extract left border of the bin "1000000+" or "1500-6000": - dist_bin_left = ( + dist_bin_left = int( bin_range.strip("+") if bin_range.endswith("+") else bin_range.split("-")[0] ) - # get the index of that bin: - bin_idx = ( - np.searchsorted(self._dist_bins, int(dist_bin_left), "right") - 1 - ) # store corresponding value: - return self._stat[filter]["dist_freq"][dirs][bin_idx] + return self._stat[filter]["dist_freq"][dirs][dist_bin_left] else: raise ValueError( "{} is not a valid key: {} section implies 2 identifiers".format( @@ -337,20 +333,13 @@ def from_file(cls, file_handle): # there is only genomic distance range of the bin that's left: (bin_range,) = key_fields # extract left border of the bin "1000000+" or "1500-6000": - dist_bin_left = ( + dist_bin_left = int( bin_range.strip("+") if bin_range.endswith("+") else bin_range.split("-")[0] ) - # get the index of that bin: - bin_idx = ( - np.searchsorted( - stat_from_file._dist_bins, int(dist_bin_left), "right" - ) - - 1 - ) # store corresponding value: - stat_from_file._stat[default_filter][key][dirs][bin_idx] = int( + stat_from_file._stat[default_filter][key][dirs][dist_bin_left] = int( fields[1] ) else: @@ -446,10 +435,10 @@ def add_pair( if chrom1 == chrom2: self._stat[filter]["cis"] += 1 dist = np.abs(pos2 - pos1) - bin = self._dist_bins[ + dist_bin = self._dist_bins[ np.searchsorted(self._dist_bins, dist, "right") - 1 ] - self._stat[filter]["dist_freq"][strand1 + strand2][bin] += 1 + self._stat[filter]["dist_freq"][strand1 + strand2][dist_bin] += 1 if dist >= 1000: self._stat[filter]["cis_1kb+"] += 1 if dist >= 2000: @@ -702,17 +691,19 @@ def flatten(self, filter="no_filter"): if (k == "dist_freq") and v: for i in range(len(self._dist_bins)): for dirs, freqs in v.items(): + dist = self._dist_bins[i] # last bin is treated differently: "100000+" vs "1200-3000": - if i != len(self._dist_bins) - 1: - dist = self._dist_bins[i] + if i < len(self._dist_bins) - 1: dist_next = self._dist_bins[i + 1] formatted_key = self._KEY_SEP.join( ["{}", "{}-{}", "{}"] ).format(k, dist, dist_next, dirs) - else: + elif i == len(self._dist_bins) - 1: formatted_key = self._KEY_SEP.join( ["{}", "{}+", "{}"] ).format(k, dist, dirs) + else: + raise ValueError("There is a mismatch between dist_freq bins in the instance") # store key,value pair: flat_stat[formatted_key] = freqs[dist] elif (k in ["pair_types", "dedup", "chromsizes"]) and v: