Skip to content

Commit

Permalink
optimise and test MultiPartitionTreeLikelihood
Browse files Browse the repository at this point in the history
  • Loading branch information
rbouckaert committed Oct 13, 2024
1 parent cd9b72f commit ae11950
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 633 deletions.
15 changes: 12 additions & 3 deletions examples/testMultiParitionTreeLikelihood.xml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ spec="Alignment">
<parameter id="freqParameter.s:secondhalf" spec="parameter.RealParameter" dimension="4" lower="0.0" name="stateNode" upper="1.0">0.25</parameter>
<parameter id="gammaShape.s:secondhalf" spec="parameter.RealParameter" lower="0.1" name="stateNode">1.0</parameter>
<parameter id="kappa.s:secondhalf" spec="parameter.RealParameter" lower="0.0" name="stateNode">2.0</parameter>
<parameter id="mutationRate.s:secondhalf" spec="parameter.RealParameter" lower="0.0" name="stateNode">1.0</parameter>
<parameter id="mutationRate.s:firsthalf" spec="parameter.RealParameter" lower="0.0" name="stateNode">1.0</parameter>
</state>
<init id="RandomTree.t:tree" spec="RandomTree" estimate="false" initial="@Tree.t:tree" taxa="@firsthalf">
<populationModel id="ConstantPopulation0.t:tree" spec="ConstantPopulation">
Expand Down Expand Up @@ -104,7 +106,7 @@ spec="Alignment">
<distribution id="MultiPartitionTreeLikelihood0" spec="CompoundDistribution" useThreads="true">
<distribution id="treeLikelihood.firsthalf" spec="TreeLikelihood" data="@firsthalf" tree="@Tree.t:tree">
<siteModel id="SiteModel.s:firsthalf" spec="SiteModel" gammaCategoryCount="4" shape="@gammaShape.s:firsthalf">
<parameter id="mutationRate.s:firsthalf" spec="parameter.RealParameter" estimate="false" lower="0.0" name="mutationRate">1.0</parameter>
<mutationRate idref="mutationRate.s:firsthalf"/>
<parameter id="proportionInvariant.s:firsthalf" spec="parameter.RealParameter" estimate="false" lower="0.0" name="proportionInvariant" upper="1.0">0.0</parameter>
<substModel id="hky.s:firsthalf" spec="HKY" kappa="@kappa.s:firsthalf">
<frequencies id="estimatedFreqs.s:firsthalf" spec="Frequencies" frequencies="@freqParameter.s:firsthalf"/>
Expand All @@ -121,7 +123,7 @@ spec="FilteredAlignment"
data="@Primates"
filter="450-898"/>
<siteModel id="SiteModel.s:secondhalf" spec="SiteModel" gammaCategoryCount="4" shape="@gammaShape.s:secondhalf">
<parameter id="mutationRate.s:secondhalf" spec="parameter.RealParameter" estimate="false" lower="0.0" name="mutationRate">1.0</parameter>
<mutationRate idref="mutationRate.s:secondhalf"/>
<parameter id="proportionInvariant.s:secondhalf" spec="parameter.RealParameter" estimate="false" lower="0.0" name="proportionInvariant" upper="1.0">0.0</parameter>
<substModel id="hky.s:secondhalf" spec="HKY" kappa="@kappa.s:secondhalf">
<frequencies id="estimatedFreqs.s:secondhalf" spec="Frequencies" frequencies="@freqParameter.s:secondhalf"/>
Expand All @@ -145,7 +147,12 @@ filter="450-898"/>
<operator id="YuleModelBICEPSEpochTop.t:tree" spec="EpochFlexOperator" scaleFactor="0.1" tree="@Tree.t:tree" weight="2.0"/>
<operator id="YuleModelBICEPSEpochAll.t:tree" spec="EpochFlexOperator" fromOldestTipOnly="false" scaleFactor="0.1" tree="@Tree.t:tree" weight="2.0"/>
<operator id="YuleModelBICEPSTreeFlex.t:tree" spec="TreeStretchOperator" scaleFactor="0.01" tree="@Tree.t:tree" weight="2.0"/>

<operator id="FixMeanMutationRatesOperator" spec="operator.kernel.BactrianDeltaExchangeOperator" delta="0.75" weight="2.0">
<parameter idref="mutationRate.s:firsthalf"/>
<parameter idref="mutationRate.s:secondhalf"/>
<weightvector id="weightparameter" spec="parameter.IntegerParameter" dimension="2" estimate="false" lower="0" upper="0">449 449</weightvector>
</operator>

<operator id="gammaShapeScaler.s:firsthalf" spec="AdaptableOperatorSampler" weight="0.05">
<parameter idref="gammaShape.s:firsthalf"/>
<operator id="AVMNOperator.firsthalf" spec="kernel.AdaptableVarianceMultivariateNormalOperator" allowNonsense="true" beta="0.05" burnin="400" initial="800" weight="0.1">
Expand Down Expand Up @@ -231,6 +238,8 @@ filter="450-898"/>

<log idref="MultiPartitionTreeLikelihood0"/>
<log idref="MultiPartitionTreeLikelihood"/>
<log idref="mutationRate.s:firsthalf"/>
<log idref="mutationRate.s:secondhalf"/>
</logger>
<logger id="screenlog" spec="Logger" logEvery="100">
<log idref="posterior"/>
Expand Down
3 changes: 2 additions & 1 deletion src/beastlabs/evolution/likelihood/BeagleDebugger.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ public class BeagleDebugger implements Beagle {

public boolean output = false;

public BeagleDebugger(Beagle beagle) {
public BeagleDebugger(Beagle beagle, boolean output) {
this.beagle = beagle;
this.output = output;
}

public void finalize() throws Throwable {
Expand Down
102 changes: 66 additions & 36 deletions src/beastlabs/evolution/likelihood/MultiPartitionTreeLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ public class MultiPartitionTreeLikelihood extends Distribution {
private static final boolean DEBUG = false;
static int instanceCount;


private double [][] lastKnownFrequencies;
private double [][] lastKnownCategorieWeights;
private double [][] lastKnownCategorieRates;

public int matrixUpdateCount;
public int partialUpdateCount;

public static boolean IS_MULTI_PARTITION_RECOMMENDED() {
if (!IS_MULTI_PARTITION_COMPATIBLE()) {
return false;
Expand Down Expand Up @@ -165,9 +173,9 @@ private static void fetchBeagleSettings() {
private static final int RESCALE_TIMES = 1;

// count the number of partial likelihood and matrix updates
private long totalMatrixUpdateCount = 0;
private long totalPartialsUpdateCount = 0;
private long totalEvaluationCount = 0;
public long totalMatrixUpdateCount = 0;
public long totalPartialsUpdateCount = 0;
public long totalEvaluationCount = 0;

private TreeInterface tree;
private BranchRateModel branchRateModel;
Expand Down Expand Up @@ -568,7 +576,7 @@ public void initialise(TreeInterface tree,
requirementFlags
);

// BeagleDebugger debugger = new BeagleDebugger(beagle);
// BeagleDebugger debugger = new BeagleDebugger(beagle, true);
// beagle = debugger;

InstanceDetails instanceDetails = beagle.getDetails();
Expand Down Expand Up @@ -686,6 +694,10 @@ public void initialise(TreeInterface tree,
throw new RuntimeException(mte.toString());
}

lastKnownFrequencies = new double[siteRateModels.size()][stateCount];
lastKnownCategorieWeights = new double[siteRateModels.size()][categoryCount];
lastKnownCategorieRates = new double[siteRateModels.size()*2][categoryCount];

instanceCount ++;
}

Expand Down Expand Up @@ -718,11 +730,11 @@ public int getTraitDim() {
// return RateRescalingScheme.NONE;
// }

private void updateSubstitutionModels(boolean... state) {
for (int i = 0; i < updateSubstitutionModels.length; i++) {
updateSubstitutionModels[i] = (state.length < 1 || state[0]);
}
}
// private void updateSubstitutionModels(boolean... state) {
// for (int i = 0; i < updateSubstitutionModels.length; i++) {
// updateSubstitutionModels[i] = (state.length < 1 || state[0]);
// }
// }

// private void updateSubstitutionModel(BranchRateModel branchModel) {
// for (int i = 0; i < branchModels.size(); i++) {
Expand All @@ -732,11 +744,11 @@ private void updateSubstitutionModels(boolean... state) {
// }
// }

private void updateSiteRateModels(boolean... state) {
for (int i = 0; i < updateSiteRateModels.length; i++) {
updateSiteRateModels[i] = (state.length < 1 || state[0]);
}
}
// private void updateSiteRateModels(boolean... state) {
// for (int i = 0; i < updateSiteRateModels.length; i++) {
// updateSiteRateModels[i] = (state.length < 1 || state[0]);
// }
// }

// private void updateSiteRateModel(SiteModel siteRateModel) {
// for (int i = 0; i < siteRateModels.size(); i++) {
Expand Down Expand Up @@ -954,7 +966,7 @@ public double calculateLogP() {
// final List<NodeOperation> nodeOperations = getNodeOperations();

if (COUNT_TOTAL_OPERATIONS) {
totalMatrixUpdateCount += branchOperations.size();
//totalMatrixUpdateCount += branchOperations.size();
//totalOperationCount += nodeOperations.size();
}

Expand Down Expand Up @@ -1078,7 +1090,10 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
categoryRateBufferHelper[k].flipOffset(0);
}

beagle.setCategoryRatesWithIndex(categoryRateBufferHelper[k].getOffsetIndex(0), categoryRates);
if (changed(categoryRates, lastKnownCategorieRates[categoryRateBufferHelper[k].getOffsetIndex(0)])) {
beagle.setCategoryRatesWithIndex(categoryRateBufferHelper[k].getOffsetIndex(0), categoryRates);
System.arraycopy(categoryRates, 0, lastKnownCategorieRates[categoryRateBufferHelper[k].getOffsetIndex(0)], 0, categoryCount);
}
updatePartition[k] = true;
if (DEBUG) {
System.out.println("updateSiteRateModels, updatePartition["+k+"] = " + updatePartition[k]);
Expand All @@ -1103,7 +1118,7 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
int [] probabilityIndices = new int [branchUpdateCount * partitionCount];
double[] edgeLengths = new double[branchUpdateCount * partitionCount];

int operationCount = 0;
matrixUpdateCount = 0;
int partition = 0;
for (SubstitutionModel evolutionaryProcessDelegate : substitutionModels) {
if (updatePartition[partition] || updateAllPartitions) {
Expand All @@ -1113,12 +1128,12 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
}

for (int i = 0; i < branchUpdateCount; i++) {
eigenDecompositionIndices[operationCount] = eigenBufferHelper[partition].getOffsetIndex(0);
eigenDecompositionIndices[matrixUpdateCount] = eigenBufferHelper[partition].getOffsetIndex(0);
// = evolutionaryProcessDelegate.getEigenIndex(0);
categoryRateIndices[operationCount] = categoryRateBufferHelper[partition].getOffsetIndex(0);
probabilityIndices[operationCount] = /*evolutionaryProcessDelegate*/getMatrixIndex(branchUpdateIndices[i], partition);
edgeLengths[operationCount] = branchLengths[i];
operationCount++;
categoryRateIndices[matrixUpdateCount] = categoryRateBufferHelper[partition].getOffsetIndex(0);
probabilityIndices[matrixUpdateCount] = /*evolutionaryProcessDelegate*/getMatrixIndex(branchUpdateIndices[i], partition);
edgeLengths[matrixUpdateCount] = branchLengths[i];
matrixUpdateCount++;
}
}
partition++;
Expand All @@ -1131,10 +1146,10 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
null, // firstDerivativeIndices
null, // secondDerivativeIndices
edgeLengths,
operationCount);
matrixUpdateCount);

if (COUNT_CALCULATIONS) {
totalMatrixUpdateCount += operationCount;
totalMatrixUpdateCount += matrixUpdateCount;
}

}
Expand All @@ -1157,7 +1172,7 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
}


int operationCount = 0;
partialUpdateCount = 0;
k = 0;
for (NodeOperation op : nodeOperations) {
int nodeNum = op.getNodeNumber();
Expand Down Expand Up @@ -1235,17 +1250,17 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
}

k += Beagle.PARTITION_OPERATION_TUPLE_SIZE;
operationCount++;
partialUpdateCount++;
}

}
}

beagle.updatePartialsByPartition(operations, operationCount);
beagle.updatePartialsByPartition(operations, partialUpdateCount);

if (COUNT_CALCULATIONS) {
totalEvaluationCount += 1;
totalPartialsUpdateCount += operationCount;
totalPartialsUpdateCount += partialUpdateCount;
}


Expand All @@ -1271,14 +1286,19 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
// double[] scaleFactors = new double[totalPatternCount];
// beagle.getLogScaleFactors(cumulateScaleBufferIndex, scaleFactors);

// these could be set only when they change but store/restore would need to be considered
for (int i = 0; i < siteRateModels.size(); i++) {
double[] categoryWeights = this.siteRateModels.get(i).getCategoryProportions(null);
beagle.setCategoryWeights(i, categoryWeights);
if (changed(categoryWeights, lastKnownCategorieWeights[i])) {
beagle.setCategoryWeights(i, categoryWeights);
System.arraycopy(categoryWeights, 0, lastKnownCategorieWeights[i], 0, categoryWeights.length);
}

// This should probably explicitly be the state frequencies for the root node...
double[] frequencies = substitutionModels.get(i).getFrequencies();
beagle.setStateFrequencies(i, frequencies);
if (changed(frequencies, lastKnownFrequencies[i])) {
beagle.setStateFrequencies(i, frequencies);
System.arraycopy(frequencies, 0, lastKnownFrequencies[i], 0, frequencies.length);
}
}


Expand Down Expand Up @@ -1360,8 +1380,10 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
// }
// beagle.getSiteLogLikelihoods(patternLogLikelihoods);

updateSubstitutionModels(false);
updateSiteRateModels(false);
// updateSubstitutionModels(false);
Arrays.fill(updateSubstitutionModels, false);
// updateSiteRateModels(false);
Arrays.fill(updateSiteRateModels, false);
updateAllPartitions = true;

if (Double.isNaN(tmpLogL) || Double.isInfinite(tmpLogL)) {
Expand Down Expand Up @@ -1470,7 +1492,16 @@ public double calculateLikelihood(List<BranchOperation> branchOperations, List<N
// }


private void updateSubstitutionModels(int partition, Beagle beagle, boolean flip) {
/**
 * Reports whether {@code array} differs from the cached {@code lastKnown} copy.
 * Uses exact {@code !=} comparison on purpose: any bitwise-meaningful change
 * (including a NaN appearing) counts as changed, which is the safe direction
 * for deciding whether BEAGLE buffers must be refreshed. Assumes
 * {@code lastKnown} is at least as long as {@code array}.
 *
 * @param array     current values
 * @param lastKnown previously pushed values
 * @return true iff some element of {@code array} differs from {@code lastKnown}
 */
private boolean changed(double[] array, double[] lastKnown) {
	int i = 0;
	while (i < array.length) {
		if (array[i] != lastKnown[i]) {
			return true;
		}
		i++;
	}
	return false;
}

private void updateSubstitutionModels(int partition, Beagle beagle, boolean flip) {
if (flip) {
eigenBufferHelper[partition].flipOffset(0);
}
Expand Down Expand Up @@ -2030,8 +2061,7 @@ private boolean traverseLevelOrder(final Tree tree, final Node node,
int nodeNum = node.getNr();

// First update the transition probability matrix(ices) for this branch
if (needsUpdate ||
(node.getParent() != null && node.isDirty()!= Tree.IS_CLEAN)) { //updateNode[nodeNum]) {
if (node.getParent() != null && (needsUpdate || node.isDirty()!= Tree.IS_CLEAN)) { //updateNode[nodeNum]) {
// TODO - at the moment a matrix is updated even if a branch length doesn't change

// addBranchUpdateOperation(tree, node);
Expand Down
Loading

0 comments on commit ae11950

Please sign in to comment.