diff --git a/functorch/_src/eager_transforms.py b/functorch/_src/eager_transforms.py
index b16ae97c5..29c47c8c0 100644
--- a/functorch/_src/eager_transforms.py
+++ b/functorch/_src/eager_transforms.py
@@ -32,6 +32,7 @@
     _func_increment_nesting,
     _assert_wrapped_functional,
     _propagate_functional_input_mutation,
+    set_inplace_requires_grad_allowed,
 )
 
 argnums_t = Union[int, Tuple[int, ...]]
@@ -40,7 +41,14 @@
 def _create_differentiable(inps, level=None):
     def create_differentiable(x):
         if isinstance(x, torch.Tensor):
-            return x.requires_grad_()
+            # requires_grad_() is rejected inside a function being transformed;
+            # lift that guard only for this internal call, then re-arm it.
+            try:
+                set_inplace_requires_grad_allowed(True)
+                return x.requires_grad_()
+            finally:
+                set_inplace_requires_grad_allowed(False)
+
         raise ValueError(f'Thing passed to transform API must be Tensor, '
                          f'got {type(x)}')
     return tree_map(create_differentiable, inps)
diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp
index a2c7f88b0..fcde68e46 100644
--- a/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/csrc/DynamicLayer.cpp
@@ -102,13 +102,28 @@ class FuncTorchTLS : public FuncTorchTLSBase {
   }
 
   void checkSupportsInplaceRequiresGrad() const override {
-    // Does nothing
+    TORCH_CHECK(dynamicLayerStack.size() == 0 || allow_inplace_requires_grad_,
+        "You are attempting to call Tensor.requires_grad_() (or perhaps using ",
+        "torch.autograd.functional.* APIs) inside of a function being transformed ",
+        "by a functorch transform. ",
+        "This is unsupported, please attempt to use the functorch transforms ",
+        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() "
+        "outside of a function being transformed instead.");
   }
   void checkSupportsRetainGrad() const override {
-    // Does nothing
+    TORCH_CHECK(dynamicLayerStack.size() == 0,
+        "You are attempting to call Tensor.retain_grad() ",
+        "inside of a function being transformed ",
+        "by a functorch transform. ",
+        "This is unsupported, please attempt to use the functorch transforms ",
+        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call retain_grad() "
+        "outside of a function being transformed instead.");
   }
 
   std::vector<DynamicLayer> dynamicLayerStack;
+  // When true, checkSupportsInplaceRequiresGrad() passes even while
+  // dynamicLayerStack is non-empty.
+  bool allow_inplace_requires_grad_ = false;
 };
 
 static FuncTorchTLS* getRawFunctorchTLS() {
@@ -122,6 +137,13 @@ static FuncTorchTLS* getRawFunctorchTLS() {
   return result;
 }
 
+// Sets allow_inplace_requires_grad_ on the TLS object from getRawFunctorchTLS().
+void setInplaceRequiresGradAllowed(bool allowed) {
+  auto* functorch_tls = getRawFunctorchTLS();
+  functorch_tls->allow_inplace_requires_grad_ = allowed;
+}
+
+
 static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
   return getRawFunctorchTLS()->dynamicLayerStack;
 }
diff --git a/functorch/csrc/DynamicLayer.h b/functorch/csrc/DynamicLayer.h
index cf84311b4..fe912980c 100644
--- a/functorch/csrc/DynamicLayer.h
+++ b/functorch/csrc/DynamicLayer.h
@@ -85,6 +85,9 @@ Tensor unwrapIfDead(const Tensor& tensor);
 std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
 
+// Controls whether Tensor.requires_grad_() is accepted while the dynamic layer stack is non-empty.
+void setInplaceRequiresGradAllowed(bool allowed);
+
 }
 } // namespace at
 
diff --git a/functorch/csrc/init.cpp b/functorch/csrc/init.cpp
index 35ce5f34f..b0699ce7c 100644
--- a/functorch/csrc/init.cpp
+++ b/functorch/csrc/init.cpp
@@ -379,6 +379,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("_set_vmap_fallback_warning_enabled", &at::functorch::setVmapFallbackWarningEnabled, "Set vmap fallback warnings");
   m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
   m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
+  m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed);
   m.def("dlevel", &at::functorch::dlevel, "dlevel");
   m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor");
   m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
diff --git a/test/test_eager_transforms.py b/test/test_eager_transforms.py
index 9d74f7541..df325f5aa 100644
--- a/test/test_eager_transforms.py
+++ b/test/test_eager_transforms.py
@@ -2157,6 +2157,93 @@ def f(x):
         new_cotangent = torch.randn(())
         self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
 
+    def test_requires_grad_inside_transform(self, device):
+        def f(x):
+            x.requires_grad_()
+            return x.sin().sum()
+
+        x = torch.randn(3)
+
+        with self.assertRaisesRegex(RuntimeError, r"Tensor\.requires_grad_\(\)"):
+            vmap(f)(x)
+        with self.assertRaisesRegex(RuntimeError, r"Tensor\.requires_grad_\(\)"):
+            grad(f)(x)
+        with self.assertRaisesRegex(RuntimeError, r"Tensor\.requires_grad_\(\)"):
+            vmap(grad(f))(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, r"Tensor\.requires_grad_\(\)"):
+            grad(grad(f))(x)
+
+    def test_retain_grad_inside_transform(self, device):
+        def f(x):
+            y = x.sin()
+            y.retain_grad()
+            return y.sum()
+
+        x = torch.randn(3)
+
+        with self.assertRaisesRegex(RuntimeError, r"Tensor\.retain_grad\(\)"):
+            grad(f)(x)
+
+    def test_autograd_functional_jacrev_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, r"torch\.autograd\.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, r"torch\.autograd\.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_vjp_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, r"torch\.autograd\.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, r"torch\.autograd\.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_jvp_inside_transform(self, device):
+        def f(x):
+            t = torch.ones_like(x)
+            y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,))
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, r"torch\.autograd\.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, r"torch\.autograd\.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_jacfwd_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.jacobian(
+                lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaises(RuntimeError):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaises(RuntimeError):
+            grad(f)(x)
+
 
 class TestMakeFunctional(TestCase):
     @parametrize('disable_autograd_tracking', [True, False])