Skip to content

Commit

Permalink
QR-STAR
Browse files Browse the repository at this point in the history
  • Loading branch information
ytabatabaee committed Oct 14, 2022
1 parent 5d7d06c commit be957fd
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 28 deletions.
52 changes: 31 additions & 21 deletions qr/fitness_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
from qr.adr_theory import *


def cost(u, indices, tree_shape, cost_func):
def cost(u, indices, tree_shape, cost_func, k, q_size, shape_coef, abratio):
"""
Given the probability distribution of unrooted quintet trees u,
the partial order of a tree R, and the type of the fitness function,
returns the cost Cost(R, u)
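:param float shape_coef: coefficient for shape penalty term
:param float abratio: weight multiplying the invariant score in the returned cost
:param int k: number of genes
:param int q_size: size term used together with k in the shape threshold A(k, q_size); the caller passes the number of sampled quintets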
:param np.ndarray u: unrooted quintet tree probability distribution
:param list indices: partial order on tree R in the form of a list of indices
:param str tree_shape: topological shape of tree R
Expand All @@ -16,43 +19,50 @@ def cost(u, indices, tree_shape, cost_func):
invariant_score = 0
inequality_score = 0
equiv_classes, inequalities = get_partial_order(tree_shape)
est_shape = topological_shape(u, k, q_size)
# similarity inside equiv classes
if cost_func == 'inq':
invariant_score = 0
else:
for c in equiv_classes:
intraclass_sim = 0
for i in range(len(c)):
for j in range(len(c)):
intraclass_sim += invariant_metric(u[indices[c[i]]], u[indices[c[j]]])
for c in equiv_classes:
intraclass_sim = 0
for i in range(len(c)):
for j in range(len(c)):
intraclass_sim += invariant_metric(u[indices[c[i]]], u[indices[c[j]]])
if cost_func == 'star':
invariant_score += intraclass_sim
else:
invariant_score += intraclass_sim / (len(c))

# distance between equiv classes
for ineq in inequalities:
interclass_distance = 0
for i in equiv_classes[ineq[0]]:
for j in equiv_classes[ineq[1]]:
interclass_distance += inequality_metric(u[indices[j]], u[indices[i]])
if cost_func == 'inq':
if cost_func == 'star':
inequality_score += interclass_distance
else:
inequality_score += interclass_distance / (len(equiv_classes[ineq[0]]))

return invariant_score + inequality_score
return invariant_score * abratio + inequality_score + shape_coef * int(est_shape != tree_shape)


def topological_shape(u, k, q_size):
    """
    Given the probability distribution of unrooted quintet trees u,
    estimates a tree shape label from two gaps between consecutive
    entries of the sorted distribution, compared against A(k, q_size)
    :param np.ndarray u: unrooted quintet tree probability distribution
                         (entries up to sorted position 8 are read)
    :param int k: number of genes
    :param int q_size: second argument forwarded to the threshold A(k, q_size)
    :rtype: str ('p', 'b' or 'c'; presumably the same shape codes as
            tree_shape in cost -- confirm against get_partial_order)
    """
    ascending = np.sort(u)
    # this could correspond to a lower bound on f(r) of the species tree
    cutoff = A(k, q_size)

    # gap between sorted positions 5 and 6 decides the first label
    gap_low = ascending[6] - ascending[5]
    if gap_low < cutoff:
        return 'p'

    # the gap between positions 7 and 8 is only read once the lower gap
    # has reached the cutoff
    if gap_low >= cutoff and ascending[8] - ascending[7] < cutoff:
        return 'b'

    return 'c'


def A(k, q_size):
    """
    Threshold used by topological_shape to decide whether a gap between
    sorted quintet probabilities is significant; evaluates
    2 * sqrt(log(30 * k * q_size) / (2 * k))
    :param int k: first term (callers pass the number of genes)
    :param int q_size: second term scaling the argument of the logarithm
    :rtype: float
    """
    log_term = np.log(30 * k * q_size)
    half_width = np.sqrt(log_term / (2 * k))
    return 2 * half_width


def invariant_metric(a, b):
    """
    Invariant penalty term: the absolute difference |a - b|
    :param float a, b: probabilities of two trees
    :rtype: float
    """
    difference = a - b
    return np.absolute(difference)


def inequality_metric(a, b):
    """
    Inequality penalty term: a - b when a exceeds b, zero otherwise
    :param float a, b: probabilities of two trees
    :rtype: float
    """
    excess = a - b
    # boolean mask zeroes the term whenever a does not exceed b
    is_violated = a > b
    return excess * is_violated
2 changes: 1 addition & 1 deletion qr/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '1.2.3'
__version__ = '1.2.4'

36 changes: 30 additions & 6 deletions quintet_rooting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from table_five import TreeSet

from qr.adr_theory import *
from qr.fitness_cost import cost
from qr.fitness_cost import *
from qr.quintet_sampling import *
from qr.utils import *
from qr.version import __version__
Expand All @@ -23,7 +23,9 @@ def main(args):
sampling_method = args.samplingmethod.lower()
random.seed(args.seed)
cost_func = args.cost.lower()
shape_coef = args.coef
mult_le = args.multiplicity
abratio = args.abratio

header = """*********************************
* Quintet Rooting """ + __version__ + """ *
Expand All @@ -32,6 +34,9 @@ def main(args):

# reading gene tree and unrooted species tree topology files
tns = dendropy.TaxonNamespace()
#true_s_tree = dendropy.Tree.get(path=species_tree_path, schema='newick',
# taxon_namespace=tns, rooting="force-rooted")
#print(true_s_tree)
unrooted_species = dendropy.Tree.get(path=species_tree_path, schema='newick',
taxon_namespace=tns, rooting="force-unrooted", suppress_edge_lengths=True)
if len(tns) < 5:
Expand Down Expand Up @@ -86,6 +91,8 @@ def main(args):
quintet_unrooted_indices = np.zeros(len(sample_quintet_taxa), dtype=int)
quintets_r_all = []

#est_shape = None
#true_shape = None
for j in range(len(sample_quintet_taxa)):
q_taxa = sample_quintet_taxa[j]
quintets_u = [
Expand All @@ -95,14 +102,24 @@ def main(args):
dendropy.Tree.get(data=map_taxon_namespace(str(q), q_taxa) + ';', schema='newick', rooting='force-rooted',
taxon_namespace=tns) for q in rooted_quintets_base]
subtree_u = unrooted_species.extract_tree_with_taxa_labels(labels=q_taxa, suppress_unifurcations=True)
#subtree_t = true_s_tree.extract_tree_with_taxa_labels(labels=q_taxa, suppress_unifurcations=True)
quintet_counts = np.asarray(gene_trees.tally_single_quintet(q_taxa))
quintet_normalizer = sum(quintet_counts) if args.normalized else len(gene_trees)
quintet_tree_dist = quintet_counts
if quintet_normalizer != 0:
quintet_tree_dist = quintet_tree_dist / quintet_normalizer
quintet_unrooted_indices[j] = get_quintet_unrooted_index(subtree_u, quintets_u)
#est_shape = topological_shape(quintet_tree_dist, len(gene_trees), len(taxon_set))
#for i in range(len(quintets_r)):
# if dendropy.calculate.treecompare.symmetric_difference(quintets_r[i], true_s_tree) == 0:
# true_shape = idx_2_unlabeled_topology(i)
# break
#print(quintet_tree_dist)
#sys.stdout.write('Estimated Shape: \n%s \n' % est_shape)
#sys.stdout.write('Real Shape: \n%s \n' % true_shape)
quintet_scores[j] = compute_cost_rooted_quintets(quintet_tree_dist, quintet_unrooted_indices[j],
rooted_quintet_indices, cost_func)
rooted_quintet_indices, cost_func, len(gene_trees),
len(sample_quintet_taxa), shape_coef, abratio)
quintets_r_all.append(quintets_r)

sys.stdout.write('Preprocessing time: %.2f sec\n' % (time.time() - proc_time))
Expand Down Expand Up @@ -142,7 +159,7 @@ def main(args):
sys.stdout.write('Total execution time: %.2f sec\n' % (time.time() - st_time))


def compute_cost_rooted_quintets(u_distribution, u_idx, rooted_quintet_indices, cost_func):
def compute_cost_rooted_quintets(u_distribution, u_idx, rooted_quintet_indices, cost_func, k, q_size, shape_coef, abratio):
"""
Scores the 7 possible rootings of an unrooted quintet
:param np.ndarray u_distribution: unrooted quintet tree probability distribution
Expand All @@ -157,7 +174,7 @@ def compute_cost_rooted_quintets(u_distribution, u_idx, rooted_quintet_indices,
idx = rooted_tree_indices[i]
unlabeled_topology = idx_2_unlabeled_topology(idx)
indices = rooted_quintet_indices[idx]
costs[i] = cost(u_distribution, indices, unlabeled_topology, cost_func)
costs[i] = cost(u_distribution, indices, unlabeled_topology, cost_func, k, q_size, shape_coef, abratio)
return costs


Expand Down Expand Up @@ -204,7 +221,7 @@ def parse_args():
"linear)", required=False, default='d')

parser.add_argument("-c", "--cost", type=str,
help="cost function (INQ for inequalities only)",
help="cost function (STAR for running QR*)",
required=False, default='d')

parser.add_argument("-cfs", "--confidencescore", action='store_true',
Expand All @@ -215,7 +232,14 @@ def parse_args():
required=False, default=1)

parser.add_argument("-norm", "--normalized", action='store_true',
help="normalization for unresolved gene trees or missing taxa")
help="normalization for unresolved gene trees or missing taxa",
required=False, default=False)

parser.add_argument("-coef", "--coef", type=float,
help="coefficient for shape penalty term", required=False, default=0)

parser.add_argument("-abratio", "--abratio", type=float,
help="Ratio between invariant and inequality penalties used in QR*", required=False, default=1)

parser.add_argument("-rs", "--seed", type=int,
help="random seed", required=False, default=1234)
Expand Down

0 comments on commit be957fd

Please sign in to comment.