diff --git a/PTMCMCSampler/PTMCMCSampler.py b/PTMCMCSampler/PTMCMCSampler.py index 8063ddc..11f1530 100755 --- a/PTMCMCSampler/PTMCMCSampler.py +++ b/PTMCMCSampler/PTMCMCSampler.py @@ -1,6 +1,7 @@ import os import sys import time +import warnings import numpy as np @@ -69,6 +70,8 @@ class PTSampler(object): @param outDir: Full path to output directory for chain files (default = ./chains) @param verbose: Update current run-status to the screen (default=True) @param resume: Resume from a previous chain (still in testing so beware) (default=False) + @param ihcwr: Ignore hot chains when resuming. Note, hot chains will enter a burn-in + phase when resuming (default=False) """ @@ -89,6 +92,7 @@ def __init__( outDir="./chains", verbose=True, resume=False, + ihcwr=False, seed=None, ): # MPI initialization @@ -117,6 +121,7 @@ def __init__( self.outDir = outDir self.verbose = verbose self.resume = resume + self.ignoreHotChainsWhenResuming = ihcwr # setup output file if not os.path.exists(self.outDir): @@ -297,16 +302,23 @@ def initialize( except ValueError as error: print("Reading old chain files failed with error", error) raise Exception("Couldn't read old chain to resume") - self._chainfile = open(self.fname, "a") + if ( self.isave != self.thin and self.resumeLength % (self.isave / self.thin) != 1 # This special case is always OK ): # Initial sample plus blocks of isave/thin - raise Exception( - ( - "Old chain has {0} rows, which is not the initial sample plus a multiple of isave/thin = {1}" - ).format(self.resumeLength, self.isave // self.thin) - ) + if self.MPIrank != 0 and self.writeHotChains is False and self.ignoreHotChainsWhenResuming: + warnings.warn("Neglecting hot chains from the previous run. It is recommended to set writeHotChains=True when resuming a run with multiple temperatures.") + self._chainfile = open(self.fname, "w") + else: + raise Exception( + ( + "Old chain has {0} rows, which is not the initial sample plus a multiple of isave/thin = {1}" + ).format(self.resumeLength, self.isave // self.thin) + ) + else: + self._chainfile = open(self.fname, "a") + print( "Resuming with", self.resumeLength,