diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 57a7ac68f..12d88f86f 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,31 +160,67 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                     fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                    # print(fw_module.code, bw_module.code)
 
                     compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                     fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    if partition_fn is default_partition:
+                        ctx.num_intermediate = len(fw_outs[num_outs:])
+                        ctx.num_inputs = len(flat_tensor_args)
+                        to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                        ctx.fx_g = fx_g
+                        ctx.save_for_backward(*to_be_saved)
+                        ctx.fwd_graph = fw_module.code
+                        ctx.bw_graph = bw_module.code
+                    else:
+                        nonlocal compiled_bw
+                        bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                        compiled_bw = bw_compiler(bw_module, bw_args)
+                        ctx.save_for_backward(*fw_outs[num_outs:])
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                    fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                    assert fx_g.code == ctx.fx_g.code
+                f = aot_function(bw_module, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+                print("INPUTS----->", *ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                print(bw_module.code)
+                out = f(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                return out
+            else:
+                if partition_fn is default_partition:
+                    out = normalize_as_list(compiled_bw(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args))
+                else:
+                    assert not torch.is_grad_enabled()
+                    out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+                return tuple(out)
 
     return CompiledFunction
 
diff --git a/functorch/_src/partitioners.py b/functorch/_src/partitioners.py
index 550e2b7a4..755502f9c 100644
--- a/functorch/_src/partitioners.py
+++ b/functorch/_src/partitioners.py
@@ -153,7 +153,7 @@ def default_partition(
                 saved_values.append(user)
         else:
             saved_values.append(node)
-    saved_values = list(set(saved_values))
+    saved_values = list(saved_values)
 
     return _extract_fwd_bwd_modules(joint_module, saved_values)
 
diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py
index ae399fc81..7ec2868da 100644
--- a/test/test_pythonkey.py
+++ b/test/test_pythonkey.py
@@ -246,14 +246,45 @@ def f(args, kwargs):
 
 def _outs_and_grads(fn, inps):
     outs = fn(*inps)
+    diff_outs = []
     for out in pytree.tree_flatten(outs)[0]:
         if isinstance(out, torch.Tensor) and out.requires_grad:
-            out.sum().backward(retain_graph=True)
-    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
-    for inp in pytree.tree_flatten(inps)[0]:
-        inp.grad = None
+            diff_outs.append(out)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res = res + out.sum()
+        return res
+    # print(inps)
+    grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True)
     return outs, grads
 
+def _outs_and_grads_and_grad_grads(fn, inps):
+    outs = fn(*inps)
+    diff_outs = []
+    diff_inps = []
+    for out in pytree.tree_flatten(outs)[0]:
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            diff_outs.append(out)
+    for inp in pytree.tree_flatten(inps)[0]:
+        if isinstance(inp, torch.Tensor) and inp.requires_grad:
+            diff_inps.append(inp)
+    def full_reduce(outs):
+        res = 0
+        # print("entering full_reduce: ", type(outs))
+        for out in outs:
+            res = res + out.sum()
+        return res
+    print("diff_outs, diff_inps: ", diff_outs, diff_inps)
+    grads = torch.autograd.grad(diff_outs, diff_inps, create_graph=True)
+    # print("grad call with: ", full_reduce(diff_outs), diff_inps)
+    diff_grads = []
+    for grad_ in grads:
+        if isinstance(grad_, torch.Tensor) and grad_.requires_grad:
+            diff_grads.append(grad_)
+    # print("grad grad call with: ", grads, full_reduce(diff_grads), diff_inps)
+    grad_grads = torch.autograd.grad(diff_grads, diff_inps)
+    return outs, grads, grad_grads
 
 class TestAOTAutograd(TestCase):
     def verify_aot_autograd(self, f, inp):
@@ -261,10 +292,11 @@ def verify_aot_autograd(self, f, inp):
             compiled_f = aot_module(f, nop)
         else:
             compiled_f = aot_function(f, nop)
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
+        ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp)
+        test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp)
         self.assertEqual(ref_out, test_out)
         self.assertEqual(ref_grad, test_grad)
+        self.assertEqual(ref_grad_grad, test_grad_grad)
 
     def test_single_output(self):
         def f(a, b):
@@ -284,6 +316,12 @@ def f(a, b):
         inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
         self.verify_aot_autograd(f, inp)
 
+    def test_cube(self):
+        def f(a):
+            return a ** 3
+        inp = [torch.tensor(2.3, requires_grad=True)]
+        self.verify_aot_autograd(f, inp)
+
     def test_no_grad_input_output(self):
         def f(a, b):
             return a.cos(), b.cos(), a * b
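
Note (not part of the patch): the new _outs_and_grads_and_grad_grads helper and test_cube exercise double backwards through the compiled function. Below is a minimal standalone sketch of that property, assuming the functorch.compile entry points (aot_function, nop) already imported by test_pythonkey.py; it is an illustration of the behaviour being tested, not code from the diff.

import torch
from functorch.compile import aot_function, nop  # assumed import path, mirroring the tests

def f(a):
    return a ** 3

# No partition_fn is passed, so this goes through default_partition,
# the branch this patch special-cases for double backwards.
compiled_f = aot_function(f, nop)

a = torch.tensor(2.3, requires_grad=True)
out = compiled_f(a)

# First derivative, kept differentiable via create_graph=True: d(a**3)/da = 3 * a**2.
(grad,) = torch.autograd.grad(out, a, create_graph=True)
# Second derivative: d(3 * a**2)/da = 6 * a; this is the case the new
# default_partition path in backward() is meant to support.
(grad_grad,) = torch.autograd.grad(grad, a)

print(grad)       # ~15.87 (= 3 * 2.3**2)
print(grad_grad)  # ~13.8  (= 6 * 2.3)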