Skip to content

Commit

Permalink
rename to SelfTuningCompoundDistribution + start debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
rbouckaert committed Oct 21, 2024
1 parent 500a7e3 commit b4baa53
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ private static void fetchBeagleSettings() {

@Override
public void initAndValidate() {
List<Alignment> Alignments = new ArrayList<>();
List<Alignment> alignments = new ArrayList<>();
// List<BranchRateModel> branchModels = new ArrayList<>();
siteRateModels = new ArrayList<>();

Expand Down Expand Up @@ -231,7 +231,7 @@ public void initAndValidate() {
tl.getID() + " and " + likelihoodsInput.get().get(0).getID());
}

Alignments.add(tl.dataInput.get());
alignments.add(tl.dataInput.get());
//branchModels.add(tl.branchRateModelInput.get());
siteRateModels.add((SiteModel) tl.siteModelInput.get());

Expand All @@ -246,7 +246,7 @@ public void initAndValidate() {
}

try {
initialise(tree, Alignments, /*branchModels,*/ siteRateModels, useAmbiguities, useTipLikelihoods, rescalingScheme, delayScalingUntillUnderflowInput.get());
initialise(tree, alignments, branchRateModel, siteRateModels, useAmbiguities, useTipLikelihoods, rescalingScheme, delayScalingUntillUnderflowInput.get());
} catch (DelegateTypeException e) {
e.printStackTrace();
}
Expand Down Expand Up @@ -277,15 +277,15 @@ static public PartialsRescalingScheme getRescalingScheme(GenericTreeLikelihood t
*
* @param tree Used for configuration - shouldn't be watched for changes
* @param branchModels Specifies a list of branch models for each partition
* @param Alignments List of Alignments comprising each partition
* @param alignments List of Alignments comprising each partition
* @param siteRateModels A list of siteRateModels for each partition
* @param useAmbiguities Whether to respect state ambiguities in data
*/


public void initialise(TreeInterface tree,
List<Alignment> Alignments,
//List<BranchRateModel> branchModels,
List<Alignment> alignments,
BranchRateModel branchRateModel,
List<SiteModel> siteRateModels,
boolean useAmbiguities,
boolean useTipLikelihoods,
Expand All @@ -295,16 +295,17 @@ public void initialise(TreeInterface tree,


//setID(Alignments.get(0).getID());

this.Alignments = Alignments;
this.dataType = Alignments.get(0).getDataType();
this.tree = tree;
this.branchRateModel = branchRateModel;
this.Alignments = alignments;
this.dataType = alignments.get(0).getDataType();
stateCount = dataType.getStateCount();

partitionCount = Alignments.size();
partitionCount = alignments.size();
patternCounts = new int[partitionCount];
int total = 0;
int k = 0;
for (Alignment Alignment : Alignments) {
for (Alignment Alignment : alignments) {
assert(Alignment.getDataType().equals(this.dataType));
patternCounts[k] = Alignment.getPatternCount();
total += patternCounts[k];
Expand Down Expand Up @@ -335,9 +336,11 @@ public void initialise(TreeInterface tree,

// SiteRateModels determine the rates per category (for site-heterogeneity models).
// There can be either one per partition or one shared across all partitions
assert(siteRateModels.size() == 1 || (siteRateModels.size() == Alignments.size()));
assert(siteRateModels.size() == 1 || (siteRateModels.size() == alignments.size()));

//this.siteRateModels.addAll(siteRateModels);
if (this.siteRateModels.size() == 0l) {
this.siteRateModels.addAll(siteRateModels);
}
this.categoryCount = this.siteRateModels.get(0).getCategoryCount();

nodeCount = tree.getNodeCount();
Expand Down Expand Up @@ -651,7 +654,7 @@ public void initialise(TreeInterface tree,

int j = 0;
k = 0;
for (Alignment Alignment : Alignments) {
for (Alignment Alignment : alignments) {
int[] pw = Alignment.getWeights();
for (int i = 0; i < Alignment.getPatternCount(); i++) {
patternPartitions[k] = j;
Expand All @@ -663,17 +666,17 @@ public void initialise(TreeInterface tree,

Log.warning(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
Log.warning.println(" " + (useTipLikelihoods ? "Using" : "Ignoring") + " character uncertainty in tree likelihood.");
String patternCountString = "" + Alignments.get(0).getPatternCount();
for (int i = 1; i < Alignments.size(); i++) {
patternCountString += ", " + Alignments.get(i).getPatternCount();
String patternCountString = "" + alignments.get(0).getPatternCount();
for (int i = 1; i < alignments.size(); i++) {
patternCountString += ", " + alignments.get(i).getPatternCount();
}
Log.warning(" With " + Alignments.size() + " partitions comprising " + patternCountString + " unique site patterns");
Log.warning(" With " + alignments.size() + " partitions comprising " + patternCountString + " unique site patterns");

for (int i = 0; i < tipCount; i++) {
if (useAmbiguities || useTipLikelihoods) {
setPartials(beagle, Alignments, tree.getNode(i));
setPartials(beagle, alignments, tree.getNode(i));
} else {
setStates(beagle, Alignments, tree.getNode(i));
setStates(beagle, alignments, tree.getNode(i));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
+ "types of BEAGLE instances, and"
+ "whether to use BEAGLE 2 or 3 API (i.e. seperate TreeLikelihoods or MultiPartitionTreeLikelihood). "
+ "Self tuning replacement of CompounDistribution with id=='likelihood'.")
public class BenchmarkingCompoundDistribution extends Distribution {
public class SelfTuningCompoundDistribution extends Distribution {
// no need to make this input REQUIRED. If no distribution input is
// specified the class just returns probability 1.
final public Input<List<Distribution>> pDistributions =
Expand All @@ -48,12 +48,13 @@ public class BenchmarkingCompoundDistribution extends Distribution {
final public Input<Boolean> ignoreInput = new Input<>("ignore", "ignore all distributions and return 1 as distribution (default false)", false);


final public Input<Long> swithcCountInput = new Input<>("switchCount","number of times to calculate likelihood before switching configuration", 1000l);
final public Input<Long> swithcCountInput = new Input<>("switchCount", "number of milli seconds to calculate likelihood before switching configuration", 1000l);
final public Input<Long> reconfigCountInput = new Input<>("reconfigCount", "number of times to calculate likelihood before self tuning again", 100000l);


class Configuration {
long nrOfSamples;
long totalRunTime;
double totalRunTime;
int threadCount;

Configuration(int threadCount) {
Expand Down Expand Up @@ -103,6 +104,11 @@ void reset() {
}
calculateLogP();
}

@Override
public String toString() {
return "Separate partition " + threadCount + " thread" + (threadCount > 1 ? "s" :"");
}
}

class MultiPartitionConfiguration extends Configuration {
Expand All @@ -114,6 +120,9 @@ class MultiPartitionConfiguration extends Configuration {
}

double calculateLogP() {
if (nrOfSamples == 1) {
switchTime = System.currentTimeMillis();
}
nrOfSamples++;
return mpTreeLikelihood.calculateLogP();
}
Expand All @@ -124,6 +133,11 @@ void reset() {
tree.setEverythingDirty(true);
calculateLogP();
}

@Override
public String toString() {
return "Multi Partition";
}

}

Expand Down Expand Up @@ -163,6 +177,8 @@ public void initAndValidate() {
}

ignore = ignoreInput.get();

switchCount = swithcCountInput.get();

if (pDistributions.get().size() == 0) {
logP = 0;
Expand All @@ -171,6 +187,8 @@ public void initAndValidate() {
MultiPartitionTreeLikelihood mpTreeLikelihood = createMultiPartitionTreeLikelihood();

currentConfiguration = initConfigurations(mpTreeLikelihood);
Log.warning("Starting with " + currentConfiguration.toString());

}


Expand All @@ -182,7 +200,7 @@ private MultiPartitionTreeLikelihood createMultiPartitionTreeLikelihood() {
boolean useAmbiguities = false;
boolean useTipLikelihoods = false;
PartialsRescalingScheme rescalingScheme = null;
List<Alignment> Alignments = new ArrayList<>();
List<Alignment> alignments = new ArrayList<>();
List<SiteModel> siteRateModels = new ArrayList<>();
List<GenericTreeLikelihood> distributions = new ArrayList<>();
String dataType = null;
Expand Down Expand Up @@ -235,17 +253,17 @@ private MultiPartitionTreeLikelihood createMultiPartitionTreeLikelihood() {
+ "All scaling must be the same. -- MultiPartitionTreeLikelihood not considered");
}

Alignments.add(tl.dataInput.get());
siteRateModels.add((SiteModel) tl.siteModelInput.get());
}
alignments.add(tl.dataInput.get());
siteRateModels.add((SiteModel) tl.siteModelInput.get());
distributions.add(tl);
}
}

MultiPartitionTreeLikelihood mpt = new MultiPartitionTreeLikelihood();
mpt.likelihoodsInput.get().addAll(distributions);
try {
mpt.initialise(tree, Alignments, siteRateModels, useAmbiguities, useTipLikelihoods, rescalingScheme, useTipLikelihoods);
mpt.initialise(tree, alignments, branchRateModel, siteRateModels, useAmbiguities, useTipLikelihoods, rescalingScheme, useTipLikelihoods);
} catch (Exception e) {
e.printStackTrace();
return null;
Expand All @@ -267,7 +285,7 @@ private Configuration initConfigurations(MultiPartitionTreeLikelihood mpTreeLike
}


for (int threadCount = minNrOfThreadsInput.get(); threadCount < maxNrOfThreads; threadCount++) {
for (int threadCount = minNrOfThreadsInput.get(); threadCount <= maxNrOfThreads; threadCount++) {
Configuration oneThreadCfg = new Configuration(threadCount);
configurations.add(oneThreadCfg);
}
Expand All @@ -293,16 +311,19 @@ private void switchConfiguration() {
bestConfigurationSoFar = cfg0;
for (Configuration cfg : configurations) {
double score = cfg.totalRunTime / cfg.nrOfSamples;
Log.warning(cfg.toString() + ": " + cfg.totalRunTime + "/" + cfg.nrOfSamples + " " + cfg.totalRunTime / cfg.nrOfSamples);
if (score < best) {
bestConfigurationSoFar = cfg0;
bestConfigurationSoFar = cfg;
best = score;
}
}
return;

currentConfiguration = bestConfigurationSoFar;
} else {
// continue with next configuration
currentConfiguration = configurations.get(i);
}

// continue with next configuration
currentConfiguration = configurations.get(i);
if (exec != null) {
exec.shutdown();
}
Expand All @@ -311,7 +332,9 @@ private void switchConfiguration() {
}

currentConfiguration.reset();
switchTime = System.currentTimeMillis();

Log.warning("Switching to " + currentConfiguration.toString());
switchTime = System.currentTimeMillis();
}


Expand All @@ -326,15 +349,45 @@ public double calculateLogP() {
return logP;
}
nrOfSamples++;

if (nrOfSamples % switchCount == 0) {
switchConfiguration();
}
// if (System.currentTimeMillis() - switchTime > switchCount) {
// switchConfiguration();
// }


if (nrOfSamples % reconfigCountInput.get() == 0) {
restartTuning();
}
currentConfiguration.nrOfSamples++;
logP = currentConfiguration.calculateLogP();
return logP;
}

class CoreRunnable implements java.lang.Runnable {
private void restartTuning() {
bestConfigurationSoFar = null;
for (Configuration cfg : configurations) {
cfg.nrOfSamples = 0;
cfg.totalRunTime = 0;
}

currentConfiguration = configurations.get(0);
if (exec != null) {
exec.shutdown();
}
if (currentConfiguration.threadCount > 1) {
exec = Executors.newFixedThreadPool(currentConfiguration.threadCount);
}

currentConfiguration.reset();

Log.warning("Switching to " + currentConfiguration.toString());
switchTime = System.currentTimeMillis();
}

class CoreRunnable implements java.lang.Runnable {
Distribution distr;

CoreRunnable(Distribution core) {
Expand Down Expand Up @@ -475,5 +528,13 @@ public List<Input<?>> listInputs() {
return list;
}


@Override
public void restore() {
if (currentConfiguration.nrOfSamples <= 2) {
currentConfiguration.reset();
}
super.restore();
}

} // class CompoundDistribution
1 change: 1 addition & 0 deletions version.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
<provider classname="beastlabs.evolution.likelihood.SupertreeLikelihood"/>
<provider classname="beastlabs.evolution.likelihood.TraitedTreeLikelihood"/>
<provider classname="beastlabs.evolution.likelihood.MultiPartitionTreeLikelihood"/>
<provider classname="beastlabs.evolution.likelihood.SelfTuningCompoundDistribution"/>
<provider classname="beastlabs.evolution.operators.AdaptableVarianceMultivariateNormalOperator"/>
<provider classname="beastlabs.evolution.operators.AttachAndUniformOperator"/>
<provider classname="beastlabs.evolution.operators.AttachOperator"/>
Expand Down

0 comments on commit b4baa53

Please sign in to comment.