Skip to content

Commit

Permalink
SAC API follow ups to restore old behavior (pytorch#401)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanchaol authored Jun 13, 2024
1 parent 0bf344c commit 230300b
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,26 @@
# currently selective per op and per layer checkpointing are supported
def checkpoint_wrapper(module, config):
if config.mode == "selective" and config.selective_ac_option == "op":
from torch.utils.checkpoint import create_selective_checkpoint_contexts
from torch.utils.checkpoint import (
CheckpointPolicy,
create_selective_checkpoint_contexts,
)

def _get_custom_policy(meta):
def _custom_policy(mode, func, *args, **kwargs):
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
return func in no_recompute_list and not (
to_save = func in no_recompute_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
)
return (
CheckpointPolicy.MUST_SAVE
if to_save
else CheckpointPolicy.PREFER_RECOMPUTE
)

return _custom_policy

Expand Down

0 comments on commit 230300b

Please sign in to comment.