Introduce experimental gradient accumulation API
rpsilva-aws committed Jan 21, 2025
1 parent 2f20c94 commit d6bfdd1
Showing 2 changed files with 436 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/spmd/test_train_spmd_linear_model.py
@@ -12,6 +12,8 @@
sys.path.append(parent_folder)
from utils.train_spmd_linear_model import train_and_evaluate

# CPU does not support optimization barriers, so this flag is used to disable
# the gradient checkpointing A/B test run on CPU.
SKIP_GRADIENT_CHECKPOINTING: bool = False


@@ -48,8 +50,40 @@ def test_basic(self):
baseline_losses, checkpointing_losses))


class TestSPMDLinearModelGradientAccumulation(
test_xla_sharding_base.XlaShardingTest):

def test_gradient_accumulation_matches(self):
"""Verify that gradient accumulation produces the same results and losses
with and without the XLA `While` op.
"""

COMMON_GRAD_ACC_ARGS = [
"--train_dataset_len", "65536", "--gradient_accumulation_steps", "8"
]
print('Training loop with traditional gradient accumulation')
with extended_argv(COMMON_GRAD_ACC_ARGS):
baseline_grad_acc_losses, baseline_grad_acc_result = train_and_evaluate()

print('Training loop with XLA\'s `While` gradient accumulation')
with extended_argv(COMMON_GRAD_ACC_ARGS +
["--use_gradient_accumulation_loop"]):
loop_grad_acc_losses, loop_grad_acc_result = train_and_evaluate()

# Verify that the model losses are not zero, and that the runs match.
assert all(loss != 0 for loss in baseline_grad_acc_losses)
assert all(
    torch.allclose(baseline_loss, loop_loss)
    for baseline_loss, loop_loss in zip(baseline_grad_acc_losses,
                                        loop_grad_acc_losses))
# Verify that the model produces non-zero outputs, and that the runs match.
assert not torch.any(baseline_grad_acc_result == 0)
assert torch.allclose(baseline_grad_acc_result, loop_grad_acc_result)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Parser argument relevant to the basic gradient checkpointing coverage.
parser.add_argument('--skip-gradient-checkpointing', action='store_true')
parsed_args, remaining_argv = parser.parse_known_args()
SKIP_GRADIENT_CHECKPOINTING = parsed_args.skip_gradient_checkpointing
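
For context, a minimal sketch (plain PyTorch, no XLA-specific APIs) of the traditional, unrolled gradient accumulation pattern that the baseline run above exercises; the model, optimizer, and data here are hypothetical stand-ins rather than the helpers in utils.train_spmd_linear_model, and the `While`-loop variant toggled by --use_gradient_accumulation_loop is not shown.

import torch
import torch.nn as nn

def train_with_grad_accumulation(model, optimizer, data_loader, accum_steps):
  model.train()
  optimizer.zero_grad()
  for step, (inputs, targets) in enumerate(data_loader):
    # Scale each micro-batch loss so the accumulated gradient matches a
    # single step over the combined batch.
    loss = nn.functional.mse_loss(model(inputs), targets) / accum_steps
    loss.backward()  # Gradients accumulate in parameter .grad buffers.
    if (step + 1) % accum_steps == 0:
      optimizer.step()
      optimizer.zero_grad()

torch.manual_seed(0)
model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(8)]
train_with_grad_accumulation(model, optimizer, data, accum_steps=8)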
