Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port main_function to TF2 #285

Open
wants to merge 2 commits into
base: users/boomanaiden154/main.port-main_function-to-tf2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions gematria/model/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ gematria_py_test(
size = "small",
timeout = "moderate",
srcs = ["main_function_test.py"],
tags = [
"manual",
],
deps = [
":inference",
":main_function",
Expand Down
314 changes: 124 additions & 190 deletions gematria/model/python/main_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def main(_):
from gematria.utils.python import timer
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow as tf2

_ACTION = flags.DEFINE_enum_class(
'gematria_action',
Expand Down Expand Up @@ -199,11 +200,6 @@ def main(_):
20,
'The number of learned values to display for the list.',
)
_GEMATRIA_LOG_DEVICE_PLACEMENT = flags.DEFINE_bool(
'gematria_log_device_placement',
False,
'Print TensorFlow op placement to devices to the log.',
)
_GEMATRIA_RANDOM_SEED = flags.DEFINE_integer(
'gematria_random_seed',
123456789,
Expand Down Expand Up @@ -295,15 +291,15 @@ def main(_):
'',
'The directory to which the summaries from the training are stored.',
)
_GEMATRIA_SAVE_CHECKPOINT_SECS = flags.DEFINE_integer(
'gematria_save_checkpoint_secs',
60,
'The number of seconds of training after which a checkpoint is saved.',
_GEMATRIA_SAVE_CHECKPOINT_EPOCHS = flags.DEFINE_integer(
'gematria_save_checkpoint_epochs',
100,
'The number of epochs of training after which a checkpoint is saved.',
)
_GEMATRIA_SAVE_SUMMARIES_SECS = flags.DEFINE_integer(
'gematria_save_summaries_secs',
60,
'The number of seconds of training after which summaries are saved.',
_GEMATRIA_SAVE_SUMMARIES_EPOCHS = flags.DEFINE_integer(
'gematria_save_summaries_epochs',
100,
'The number of epochs of training after which summaries are saved.',
)
_GEMATRIA_EVAL_INTERVAL_SECS = flags.DEFINE_integer(
'gematria_eval_interval_secs',
Expand Down Expand Up @@ -505,79 +501,22 @@ def _resume_from_and_resume_to_dir_must_be_used_at_the_same_time(
return bool(resume_from_dir) == bool(resume_to_dir)


def _warmstart_from_file(scaffold: tf.train.Scaffold, sess):
def _warmstart_from_file(model: model_base.ModelBase):
"""Warmstarts the model from a specific checkpoint."""
del scaffold # Unused.
if not tf.io.gfile.exists(f'{_WARMSTART_FILE.value}.index'):
raise ValueError(f'No checkpoint was found at "{_WARMSTART_FILE.value}"')
training.partially_restore_from_checkpoint(
_WARMSTART_FILE.value, _LOAD_GLOBAL_STEP_FROM_CKPT.value, sess
_WARMSTART_FILE.value, _LOAD_GLOBAL_STEP_FROM_CKPT.value, model
)


def _warmstart_from_dir(scaffold: tf.train.Scaffold, sess):
def _warmstart_from_dir(model: model_base.ModelBase):
"""Warmstarts the model from the latest checkpoint in a directory."""
del scaffold # Unused.
checkpoint = tf.train.latest_checkpoint(_WARMSTART_DIR.value)
if not checkpoint:
raise ValueError(f'No checkpoint was found at "{_WARMSTART_DIR.value}"')
training.partially_restore_from_checkpoint(
checkpoint, _LOAD_GLOBAL_STEP_FROM_CKPT.value, sess
)


def _monitored_training_session_from_flags(
model: model_base.ModelBase, is_chief: bool
) -> tf.train.MonitoredTrainingSession:
"""Creates a monitored training session for 'model' from command-line flags.

Args:
model: The model for which the session is created.
is_chief: True when this is the chief training worker in a distributed
setup.

Returns:
The monitored training session object.
"""
hooks = []
if _GEMATRIA_TRAINING_NUM_EPOCHS.value > 0:
hooks.append(
tf.train.StopAtStepHook(last_step=_GEMATRIA_TRAINING_NUM_EPOCHS.value)
)
hooks += model.get_monitored_training_session_hooks()
session_config = tf.ConfigProto(
log_device_placement=_GEMATRIA_LOG_DEVICE_PLACEMENT.value
)
scaffold_init_fn = None
if _WARMSTART_FILE.value:
# If there is a checkpoint to bootstrap from, we add an init_fn to the
# monitored session that restores it. This init_fn is called only when an
# actual checkpoint is not available to fully restore the model.
scaffold_init_fn = _warmstart_from_file
elif _WARMSTART_DIR.value:
# If there is a directory to bootstrap from, we find the latest checkpoint
# in this directory and add an init_fn to the monitored session the same way
# as with _WARMSTART_FILE above.
scaffold_init_fn = _warmstart_from_dir

scaffold = tf.train.Scaffold(
init_fn=scaffold_init_fn,
saver=tf.train.Saver(
max_to_keep=_GEMATRIA_CHECKPOINT_MAX_TO_KEEP.value,
keep_checkpoint_every_n_hours=1,
),
)

return tf.train.MonitoredTrainingSession(
checkpoint_dir=_CHECKPOINT_DIR.value,
config=session_config,
scaffold=scaffold,
hooks=hooks,
is_chief=is_chief,
master=_MASTER.value,
save_checkpoint_secs=_GEMATRIA_SAVE_CHECKPOINT_SECS.value,
save_summaries_secs=_GEMATRIA_SAVE_SUMMARIES_SECS.value,
summary_dir=_GEMATRIA_SUMMARY_DIR.value,
checkpoint, _LOAD_GLOBAL_STEP_FROM_CKPT.value, model
)


Expand Down Expand Up @@ -751,12 +690,12 @@ def _extract_basic_blocks_with_throughput(
yield block


def _session_from_checkpoint(checkpoint_file: str) -> tf.Session:
"""Creates a local TF Session and restores it from a given checkpoint file."""
sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
return sess
def _restore_model_from_checkpoint(
checkpoint_file: str, model: model_base.ModelBase
) -> None:
"""Restores a model from a checkpoint."""
checkpoint = tf2.train.Checkpoint(model)
checkpoint.restore(checkpoint_file)


def _task_names_from_command_line_flags() -> Sequence[str]:
Expand Down Expand Up @@ -793,116 +732,111 @@ def run_gematria_model_from_command_line_flags(
tf.random.set_random_seed(_GEMATRIA_RANDOM_SEED.value)
random.seed(_GEMATRIA_RANDOM_SEED.value)
is_chief = _GEMATRIA_TRAINING_TASK.value == 0
with tf.Graph().as_default():
dev = tf.train.replica_device_setter(
ps_tasks=_GEMATRIA_TRAINING_PS_TASKS.value
)
with tf.device(dev):
with timer.scoped('Creating model: ' + model_class.__name__):
num_replicas = _GEMATRIA_NUM_TRAINING_WORKER_REPLICAS.value
num_replicas_to_aggregate = (
_GEMATRIA_NUM_TRAINING_WORKER_REPLICAS_TO_AGGREGATE.value
)
model = model_class( # pytype: disable=wrong-arg-types
model_name=_MODEL_NAME.value,
task_list=_task_names_from_command_line_flags(),
synchronous_training=_GEMATRIA_SYNCHRONOUS_TRAINING.value,
loss_type=_LOSS_TYPE.value,
loss_normalization=_LOSS_NORMALIZATION.value,
trained_variable_groups=_TRAINED_VARIABLES.value,
learning_rate=_LEARNING_RATE.value,
decay_steps=_DECAY_STEPS.value,
decay_rate=_DECAY_RATE.value,
learning_rate_schedule=_LEARNING_RATE_SCHEDULE.value,
optimizer_type=_OPTIMIZER_TYPE.value,
grad_clip_norm=_GRAD_CLIP_NORM.value,
use_delta_loss=_GEMATRIA_USE_SEQ2SEQ_LOSS.value,
collected_percentile_ranks=tuple(
map(int, _COLLECTED_PERCENTILE_RANKS.value)
),
num_training_worker_replicas=num_replicas,
num_training_worker_replicas_to_aggregate=num_replicas_to_aggregate,
is_chief=is_chief,
**model_kwargs,
)
model.initialize()
with timer.scoped('Loading basic blocks'):
if _ACTION.value not in model_options.ACTIONS_WITHOUT_INPUT_DATA:
input_files = _INPUT_FILES.value
if not input_files:
sys.exit(
'At least one .tfrecord file must be specified through'
' --gematria_input_file.'
)
basic_block_protos = _make_basic_block_reader_from_command_line_flags(
input_files, _THROUGHPUT_SOURCE_FILTERS.value
)
blocks_with_throughput = _extract_basic_blocks_with_throughput(
model, basic_block_protos
)
else:
basic_block_protos = None
blocks_with_throughput = None
max_instructions_in_batch = _GEMATRIA_MAX_INSTRUCTIONS_IN_BATCH.value
if _ACTION.value == model_options.Action.EVAL:
session_hooks = None
model.run_continuous_evaluation(
tuple(blocks_with_throughput),
_CHECKPOINT_DIR.value,
_GEMATRIA_SUMMARY_DIR.value,
tf_master=_MASTER.value,
session_hooks=session_hooks,
eval_interval_seconds=_GEMATRIA_EVAL_INTERVAL_SECS.value,
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
)
elif _ACTION.value == model_options.Action.PREDICT:
with _session_from_checkpoint(_CHECKPOINT_FILE.value) as sess:
output_blocks = inference.predict_for_protos(
model,
sess,
basic_block_protos,
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
dev = tf.train.replica_device_setter(
ps_tasks=_GEMATRIA_TRAINING_PS_TASKS.value
)
with tf.device(dev):
with timer.scoped('Creating model: ' + model_class.__name__):
num_replicas = _GEMATRIA_NUM_TRAINING_WORKER_REPLICAS.value
num_replicas_to_aggregate = (
_GEMATRIA_NUM_TRAINING_WORKER_REPLICAS_TO_AGGREGATE.value
)
model = model_class( # pytype: disable=wrong-arg-types
model_name=_MODEL_NAME.value,
task_list=_task_names_from_command_line_flags(),
synchronous_training=_GEMATRIA_SYNCHRONOUS_TRAINING.value,
loss_type=_LOSS_TYPE.value,
loss_normalization=_LOSS_NORMALIZATION.value,
trained_variable_groups=_TRAINED_VARIABLES.value,
learning_rate=_LEARNING_RATE.value,
decay_steps=_DECAY_STEPS.value,
decay_rate=_DECAY_RATE.value,
learning_rate_schedule=_LEARNING_RATE_SCHEDULE.value,
optimizer_type=_OPTIMIZER_TYPE.value,
grad_clip_norm=_GRAD_CLIP_NORM.value,
use_delta_loss=_GEMATRIA_USE_SEQ2SEQ_LOSS.value,
collected_percentile_ranks=tuple(
map(int, _COLLECTED_PERCENTILE_RANKS.value)
),
num_training_worker_replicas=num_replicas,
num_training_worker_replicas_to_aggregate=num_replicas_to_aggregate,
is_chief=is_chief,
**model_kwargs,
)
model.initialize()
with timer.scoped('Loading basic blocks'):
if _ACTION.value not in model_options.ACTIONS_WITHOUT_INPUT_DATA:
input_files = _INPUT_FILES.value
if not input_files:
sys.exit(
'At least one .tfrecord file must be specified through'
' --gematria_input_file.'
)
tfrecord.write_protos(_GEMATRIA_OUTPUT_FILE.value, output_blocks)
elif _ACTION.value == model_options.Action.EXPORT_GRAPH_DEF:
graph_def = tf.get_default_graph().as_graph_def()
graph_def = tf.compat.v1.graph_util.remove_training_nodes(
graph_def, protected_nodes=model.output_tensor_names
basic_block_protos = _make_basic_block_reader_from_command_line_flags(
input_files, _THROUGHPUT_SOURCE_FILTERS.value
)
if _CHECKPOINT_FILE.value:
# When a checkpoint file is specified, replace tf.Variable nodes with
# tf.constant() nodes with the values of the variables from this
# checkpoint.
with _session_from_checkpoint(_CHECKPOINT_FILE.value) as sess:
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=graph_def,
output_node_names=model.output_tensor_names,
)
tf.io.write_graph(
graph_def,
logdir=os.path.dirname(_GRAPH_DEF_FILE.value),
name=os.path.basename(_GRAPH_DEF_FILE.value),
blocks_with_throughput = _extract_basic_blocks_with_throughput(
model, basic_block_protos
)
elif _ACTION.value == model_options.Action.TRAIN:
if is_chief:
_resume_previous_experiment_if_needed()
with timer.scoped('Create training session'):
session = _monitored_training_session_from_flags(model, is_chief)
with timer.scoped('Running the training'):
with session:
randomize_expected_outputs = (
_TRAINING_THROUGHPUT_SELECTION.value
== io_options.ThroughputSelection.RANDOM
)
model.train(
session,
tuple(blocks_with_throughput),
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
num_epochs=_GEMATRIA_TRAINING_NUM_EPOCHS.value,
randomize_batches=_GEMATRIA_TRAINING_RANDOMIZE_BATCHES.value,
randomize_expected_outputs=randomize_expected_outputs,
)
else:
basic_block_protos = None
blocks_with_throughput = None
max_instructions_in_batch = _GEMATRIA_MAX_INSTRUCTIONS_IN_BATCH.value
if _ACTION.value == model_options.Action.EVAL:
session_hooks = None
model.run_continuous_evaluation(
tuple(blocks_with_throughput),
_CHECKPOINT_DIR.value,
_GEMATRIA_SUMMARY_DIR.value,
tf_master=_MASTER.value,
session_hooks=session_hooks,
eval_interval_seconds=_GEMATRIA_EVAL_INTERVAL_SECS.value,
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
)
elif _ACTION.value == model_options.Action.PREDICT:
_restore_model_from_checkpoint(_CHECKPOINT_FILE.value, model)
output_blocks = inference.predict_for_protos(
model,
basic_block_protos,
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
)
tfrecord.write_protos(_GEMATRIA_OUTPUT_FILE.value, output_blocks)
elif _ACTION.value == model_options.Action.TRAIN:
if is_chief:
_resume_previous_experiment_if_needed()
if _WARMSTART_FILE.value:
# If there is a checkpoint to bootstrap from, we add an init_fn to the
# monitored session that restores it. This init_fn is called only when an
# actual checkpoint is not available to fully restore the model.
_warmstart_from_file(_WARMSTART_FILE.value)
elif _WARMSTART_DIR.value:
# If there is a directory to bootstrap from, we find the latest checkpoint
# in this directory and add an init_fn to the monitored session the same way
# as with _WARMSTART_FILE above.
_warmstart_from_dir(_WARMSTART_DIR.value)
randomize_expected_outputs = (
_TRAINING_THROUGHPUT_SELECTION.value
== io_options.ThroughputSelection.RANDOM
)

checkpoint = tf2.train.Checkpoint(model)
checkpoint_manager = tf2.train.CheckpointManager(
checkpoint,
_CHECKPOINT_DIR.value,
_GEMATRIA_CHECKPOINT_MAX_TO_KEEP.value,
)

def checkpoint_model():
checkpoint_manager.save()

model.train(
tuple(blocks_with_throughput),
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
num_epochs=_GEMATRIA_TRAINING_NUM_EPOCHS.value,
randomize_batches=_GEMATRIA_TRAINING_RANDOMIZE_BATCHES.value,
randomize_expected_outputs=randomize_expected_outputs,
hooks=[(_GEMATRIA_SAVE_CHECKPOINT_EPOCHS.value, checkpoint_model)],
)
Loading
Loading