Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resuming a calculation with hot chains (-nc) #54

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions PTMCMCSampler/PTMCMCSampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
import time
import warnings

import numpy as np

Expand Down Expand Up @@ -69,6 +70,8 @@ class PTSampler(object):
@param outDir: Full path to output directory for chain files (default = ./chains)
@param verbose: Update current run-status to the screen (default=True)
@param resume: Resume from a previous chain (still in testing so beware) (default=False)
@param ihcwr: Ignore hot chains when resuming. Note, hot chains will enter a burn-in
phase when resuming (default=False)

"""

Expand All @@ -89,6 +92,7 @@ def __init__(
outDir="./chains",
verbose=True,
resume=False,
ihcwr=False,
seed=None,
):
# MPI initialization
Expand Down Expand Up @@ -117,6 +121,7 @@ def __init__(
self.outDir = outDir
self.verbose = verbose
self.resume = resume
self.ignoreHotChainsWhenResuming = ihcwr

# setup output file
if not os.path.exists(self.outDir):
Expand Down Expand Up @@ -297,16 +302,23 @@ def initialize(
except ValueError as error:
print("Reading old chain files failed with error", error)
raise Exception("Couldn't read old chain to resume")
self._chainfile = open(self.fname, "a")
if (
self.isave != self.thin
and self.resumeLength % (self.isave / self.thin) != 1 # This special case is always OK
): # Initial sample plus blocks of isave/thin
raise Exception(
(
"Old chain has {0} rows, which is not the initial sample plus a multiple of isave/thin = {1}"
).format(self.resumeLength, self.isave // self.thin)
)
if self.MPIrank != 0 and self.writeHotChains is False and self.ignoreHotChainsWhenResuming:
warnings.warn("Neglecting hot chains from the previous run. It is recommended to set writeHotChains=True when resuming a run with multiple temperatures.")
self._chainfile = open(self.fname, "w")
else:
raise Exception(
(
"Old chain has {0} rows, which is not the initial sample plus a multiple of isave/thin = {1}"
).format(self.resumeLength, self.isave // self.thin)
)
else:
self._chainfile = open(self.fname, "a")

print(
"Resuming with",
self.resumeLength,
Expand Down