Skip to content

Commit

Permalink
Port main_function to TF2
Browse files Browse the repository at this point in the history
This patch ports the main_function library to TF2. Quite a bit of code
had to be adjusted to use newer APIs that are not incompatible with TF2.
Some functionality was completely removed like exporting graph defs as
it makes significantly less sense with TF2 if it is even supported.

Pull Request: google#285
  • Loading branch information
boomanaiden154 committed Jan 6, 2025
1 parent 19850c0 commit 07f1fad
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 323 deletions.
3 changes: 0 additions & 3 deletions gematria/model/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ gematria_py_test(
size = "small",
timeout = "moderate",
srcs = ["main_function_test.py"],
tags = [
"manual",
],
deps = [
":inference",
":main_function",
Expand Down
314 changes: 124 additions & 190 deletions gematria/model/python/main_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def main(_):
from gematria.utils.python import timer
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow as tf2

_ACTION = flags.DEFINE_enum_class(
'gematria_action',
Expand Down Expand Up @@ -199,11 +200,6 @@ def main(_):
20,
'The number of learned values to display for the list.',
)
_GEMATRIA_LOG_DEVICE_PLACEMENT = flags.DEFINE_bool(
'gematria_log_device_placement',
False,
'Print TensorFlow op placement to devices to the log.',
)
_GEMATRIA_RANDOM_SEED = flags.DEFINE_integer(
'gematria_random_seed',
123456789,
Expand Down Expand Up @@ -295,15 +291,15 @@ def main(_):
'',
'The directory to which the summaries from the training are stored.',
)
_GEMATRIA_SAVE_CHECKPOINT_SECS = flags.DEFINE_integer(
'gematria_save_checkpoint_secs',
60,
'The number of seconds of training after which a checkpoint is saved.',
_GEMATRIA_SAVE_CHECKPOINT_EPOCHS = flags.DEFINE_integer(
'gematria_save_checkpoint_epochs',
100,
'The number of epochs of training after which a checkpoint is saved.',
)
_GEMATRIA_SAVE_SUMMARIES_SECS = flags.DEFINE_integer(
'gematria_save_summaries_secs',
60,
'The number of seconds of training after which summaries are saved.',
_GEMATRIA_SAVE_SUMMARIES_EPOCHS = flags.DEFINE_integer(
'gematria_save_summaries_epochs',
100,
'The number of epochs of training after which summaries are saved.',
)
_GEMATRIA_EVAL_INTERVAL_SECS = flags.DEFINE_integer(
'gematria_eval_interval_secs',
Expand Down Expand Up @@ -505,79 +501,22 @@ def _resume_from_and_resume_to_dir_must_be_used_at_the_same_time(
return bool(resume_from_dir) == bool(resume_to_dir)


def _warmstart_from_file(scaffold: tf.train.Scaffold, sess):
def _warmstart_from_file(model: model_base.ModelBase):
"""Warmstarts the model from a specific checkpoint."""
del scaffold # Unused.
if not tf.io.gfile.exists(f'{_WARMSTART_FILE.value}.index'):
raise ValueError(f'No checkpoint was found at "{_WARMSTART_FILE.value}"')
training.partially_restore_from_checkpoint(
_WARMSTART_FILE.value, _LOAD_GLOBAL_STEP_FROM_CKPT.value, sess
_WARMSTART_FILE.value, _LOAD_GLOBAL_STEP_FROM_CKPT.value, model
)


def _warmstart_from_dir(scaffold: tf.train.Scaffold, sess):
def _warmstart_from_dir(model: model_base.ModelBase):
"""Warmstarts the model from the latest checkpoint in a directory."""
del scaffold # Unused.
checkpoint = tf.train.latest_checkpoint(_WARMSTART_DIR.value)
if not checkpoint:
raise ValueError(f'No checkpoint was found at "{_WARMSTART_DIR.value}"')
training.partially_restore_from_checkpoint(
checkpoint, _LOAD_GLOBAL_STEP_FROM_CKPT.value, sess
)


def _monitored_training_session_from_flags(
model: model_base.ModelBase, is_chief: bool
) -> tf.train.MonitoredTrainingSession:
"""Creates a monitored training session for 'model' from command-line flags.
Args:
model: The model for which the session is created.
is_chief: True when this is the chief training worker in a distributed
setup.
Returns:
The monitored training session object.
"""
hooks = []
if _GEMATRIA_TRAINING_NUM_EPOCHS.value > 0:
hooks.append(
tf.train.StopAtStepHook(last_step=_GEMATRIA_TRAINING_NUM_EPOCHS.value)
)
hooks += model.get_monitored_training_session_hooks()
session_config = tf.ConfigProto(
log_device_placement=_GEMATRIA_LOG_DEVICE_PLACEMENT.value
)
scaffold_init_fn = None
if _WARMSTART_FILE.value:
# If there is a checkpoint to bootstrap from, we add an init_fn to the
# monitored session that restores it. This init_fn is called only when an
# actual checkpoint is not available to fully restore the model.
scaffold_init_fn = _warmstart_from_file
elif _WARMSTART_DIR.value:
# If there is a directory to bootstrap from, we find the latest checkpoint
# in this directory and add an init_fn to the monitored session the same way
# as with _WARMSTART_FILE above.
scaffold_init_fn = _warmstart_from_dir

scaffold = tf.train.Scaffold(
init_fn=scaffold_init_fn,
saver=tf.train.Saver(
max_to_keep=_GEMATRIA_CHECKPOINT_MAX_TO_KEEP.value,
keep_checkpoint_every_n_hours=1,
),
)

return tf.train.MonitoredTrainingSession(
checkpoint_dir=_CHECKPOINT_DIR.value,
config=session_config,
scaffold=scaffold,
hooks=hooks,
is_chief=is_chief,
master=_MASTER.value,
save_checkpoint_secs=_GEMATRIA_SAVE_CHECKPOINT_SECS.value,
save_summaries_secs=_GEMATRIA_SAVE_SUMMARIES_SECS.value,
summary_dir=_GEMATRIA_SUMMARY_DIR.value,
checkpoint, _LOAD_GLOBAL_STEP_FROM_CKPT.value, model
)


Expand Down Expand Up @@ -751,12 +690,12 @@ def _extract_basic_blocks_with_throughput(
yield block


def _session_from_checkpoint(checkpoint_file: str) -> tf.Session:
"""Creates a local TF Session and restores it from a given checkpoint file."""
sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
return sess
def _restore_model_from_checkpoint(
checkpoint_file: str, model: model_base.ModelBase
) -> None:
"""Restores a model from a checkpoint."""
checkpoint = tf2.train.Checkpoint(model)
checkpoint.restore(checkpoint_file)


def _task_names_from_command_line_flags() -> Sequence[str]:
Expand Down Expand Up @@ -793,116 +732,111 @@ def run_gematria_model_from_command_line_flags(
tf.random.set_random_seed(_GEMATRIA_RANDOM_SEED.value)
random.seed(_GEMATRIA_RANDOM_SEED.value)
is_chief = _GEMATRIA_TRAINING_TASK.value == 0
with tf.Graph().as_default():
dev = tf.train.replica_device_setter(
ps_tasks=_GEMATRIA_TRAINING_PS_TASKS.value
)
with tf.device(dev):
with timer.scoped('Creating model: ' + model_class.__name__):
num_replicas = _GEMATRIA_NUM_TRAINING_WORKER_REPLICAS.value
num_replicas_to_aggregate = (
_GEMATRIA_NUM_TRAINING_WORKER_REPLICAS_TO_AGGREGATE.value
)
model = model_class( # pytype: disable=wrong-arg-types
model_name=_MODEL_NAME.value,
task_list=_task_names_from_command_line_flags(),
synchronous_training=_GEMATRIA_SYNCHRONOUS_TRAINING.value,
loss_type=_LOSS_TYPE.value,
loss_normalization=_LOSS_NORMALIZATION.value,
trained_variable_groups=_TRAINED_VARIABLES.value,
learning_rate=_LEARNING_RATE.value,
decay_steps=_DECAY_STEPS.value,
decay_rate=_DECAY_RATE.value,
learning_rate_schedule=_LEARNING_RATE_SCHEDULE.value,
optimizer_type=_OPTIMIZER_TYPE.value,
grad_clip_norm=_GRAD_CLIP_NORM.value,
use_delta_loss=_GEMATRIA_USE_SEQ2SEQ_LOSS.value,
collected_percentile_ranks=tuple(
map(int, _COLLECTED_PERCENTILE_RANKS.value)
),
num_training_worker_replicas=num_replicas,
num_training_worker_replicas_to_aggregate=num_replicas_to_aggregate,
is_chief=is_chief,
**model_kwargs,
)
model.initialize()
with timer.scoped('Loading basic blocks'):
if _ACTION.value not in model_options.ACTIONS_WITHOUT_INPUT_DATA:
input_files = _INPUT_FILES.value
if not input_files:
sys.exit(
'At least one .tfrecord file must be specified through'
' --gematria_input_file.'
)
basic_block_protos = _make_basic_block_reader_from_command_line_flags(
input_files, _THROUGHPUT_SOURCE_FILTERS.value
)
blocks_with_throughput = _extract_basic_blocks_with_throughput(
model, basic_block_protos
)
else:
basic_block_protos = None
blocks_with_throughput = None
max_instructions_in_batch = _GEMATRIA_MAX_INSTRUCTIONS_IN_BATCH.value
if _ACTION.value == model_options.Action.EVAL:
session_hooks = None
model.run_continuous_evaluation(
tuple(blocks_with_throughput),
_CHECKPOINT_DIR.value,
_GEMATRIA_SUMMARY_DIR.value,
tf_master=_MASTER.value,
session_hooks=session_hooks,
eval_interval_seconds=_GEMATRIA_EVAL_INTERVAL_SECS.value,
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
)
elif _ACTION.value == model_options.Action.PREDICT:
with _session_from_checkpoint(_CHECKPOINT_FILE.value) as sess:
output_blocks = inference.predict_for_protos(
model,
sess,
basic_block_protos,
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
dev = tf.train.replica_device_setter(
ps_tasks=_GEMATRIA_TRAINING_PS_TASKS.value
)
with tf.device(dev):
with timer.scoped('Creating model: ' + model_class.__name__):
num_replicas = _GEMATRIA_NUM_TRAINING_WORKER_REPLICAS.value
num_replicas_to_aggregate = (
_GEMATRIA_NUM_TRAINING_WORKER_REPLICAS_TO_AGGREGATE.value
)
model = model_class( # pytype: disable=wrong-arg-types
model_name=_MODEL_NAME.value,
task_list=_task_names_from_command_line_flags(),
synchronous_training=_GEMATRIA_SYNCHRONOUS_TRAINING.value,
loss_type=_LOSS_TYPE.value,
loss_normalization=_LOSS_NORMALIZATION.value,
trained_variable_groups=_TRAINED_VARIABLES.value,
learning_rate=_LEARNING_RATE.value,
decay_steps=_DECAY_STEPS.value,
decay_rate=_DECAY_RATE.value,
learning_rate_schedule=_LEARNING_RATE_SCHEDULE.value,
optimizer_type=_OPTIMIZER_TYPE.value,
grad_clip_norm=_GRAD_CLIP_NORM.value,
use_delta_loss=_GEMATRIA_USE_SEQ2SEQ_LOSS.value,
collected_percentile_ranks=tuple(
map(int, _COLLECTED_PERCENTILE_RANKS.value)
),
num_training_worker_replicas=num_replicas,
num_training_worker_replicas_to_aggregate=num_replicas_to_aggregate,
is_chief=is_chief,
**model_kwargs,
)
model.initialize()
with timer.scoped('Loading basic blocks'):
if _ACTION.value not in model_options.ACTIONS_WITHOUT_INPUT_DATA:
input_files = _INPUT_FILES.value
if not input_files:
sys.exit(
'At least one .tfrecord file must be specified through'
' --gematria_input_file.'
)
tfrecord.write_protos(_GEMATRIA_OUTPUT_FILE.value, output_blocks)
elif _ACTION.value == model_options.Action.EXPORT_GRAPH_DEF:
graph_def = tf.get_default_graph().as_graph_def()
graph_def = tf.compat.v1.graph_util.remove_training_nodes(
graph_def, protected_nodes=model.output_tensor_names
basic_block_protos = _make_basic_block_reader_from_command_line_flags(
input_files, _THROUGHPUT_SOURCE_FILTERS.value
)
if _CHECKPOINT_FILE.value:
# When a checkpoint file is specified, replace tf.Variable nodes with
# tf.constant() nodes with the values of the variables from this
# checkpoint.
with _session_from_checkpoint(_CHECKPOINT_FILE.value) as sess:
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=graph_def,
output_node_names=model.output_tensor_names,
)
tf.io.write_graph(
graph_def,
logdir=os.path.dirname(_GRAPH_DEF_FILE.value),
name=os.path.basename(_GRAPH_DEF_FILE.value),
blocks_with_throughput = _extract_basic_blocks_with_throughput(
model, basic_block_protos
)
elif _ACTION.value == model_options.Action.TRAIN:
if is_chief:
_resume_previous_experiment_if_needed()
with timer.scoped('Create training session'):
session = _monitored_training_session_from_flags(model, is_chief)
with timer.scoped('Running the training'):
with session:
randomize_expected_outputs = (
_TRAINING_THROUGHPUT_SELECTION.value
== io_options.ThroughputSelection.RANDOM
)
model.train(
session,
tuple(blocks_with_throughput),
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
num_epochs=_GEMATRIA_TRAINING_NUM_EPOCHS.value,
randomize_batches=_GEMATRIA_TRAINING_RANDOMIZE_BATCHES.value,
randomize_expected_outputs=randomize_expected_outputs,
)
else:
basic_block_protos = None
blocks_with_throughput = None
max_instructions_in_batch = _GEMATRIA_MAX_INSTRUCTIONS_IN_BATCH.value
if _ACTION.value == model_options.Action.EVAL:
session_hooks = None
model.run_continuous_evaluation(
tuple(blocks_with_throughput),
_CHECKPOINT_DIR.value,
_GEMATRIA_SUMMARY_DIR.value,
tf_master=_MASTER.value,
session_hooks=session_hooks,
eval_interval_seconds=_GEMATRIA_EVAL_INTERVAL_SECS.value,
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
)
elif _ACTION.value == model_options.Action.PREDICT:
_restore_model_from_checkpoint(_CHECKPOINT_FILE.value, model)
output_blocks = inference.predict_for_protos(
model,
basic_block_protos,
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
)
tfrecord.write_protos(_GEMATRIA_OUTPUT_FILE.value, output_blocks)
elif _ACTION.value == model_options.Action.TRAIN:
if is_chief:
_resume_previous_experiment_if_needed()
if _WARMSTART_FILE.value:
# If there is a checkpoint to bootstrap from, we add an init_fn to the
# monitored session that restores it. This init_fn is called only when an
# actual checkpoint is not available to fully restore the model.
_warmstart_from_file(_WARMSTART_FILE.value)
elif _WARMSTART_DIR.value:
# If there is a directory to bootstrap from, we find the latest checkpoint
# in this directory and add an init_fn to the monitored session the same way
# as with _WARMSTART_FILE above.
_warmstart_from_dir(_WARMSTART_DIR.value)
randomize_expected_outputs = (
_TRAINING_THROUGHPUT_SELECTION.value
== io_options.ThroughputSelection.RANDOM
)

checkpoint = tf2.train.Checkpoint(model)
checkpoint_manager = tf2.train.CheckpointManager(
checkpoint,
_CHECKPOINT_DIR.value,
_GEMATRIA_CHECKPOINT_MAX_TO_KEEP.value,
)

def checkpoint_model():
checkpoint_manager.save()

model.train(
tuple(blocks_with_throughput),
max_blocks_in_batch=_GEMATRIA_MAX_BLOCKS_IN_BATCH.value,
max_instructions_in_batch=max_instructions_in_batch,
num_epochs=_GEMATRIA_TRAINING_NUM_EPOCHS.value,
randomize_batches=_GEMATRIA_TRAINING_RANDOMIZE_BATCHES.value,
randomize_expected_outputs=randomize_expected_outputs,
hooks=[(_GEMATRIA_SAVE_CHECKPOINT_EPOCHS.value, checkpoint_model)],
)
Loading

0 comments on commit 07f1fad

Please sign in to comment.