diff --git a/include/inplace_abn.h b/include/inplace_abn.h
index 6397af9..4ea0365 100644
--- a/include/inplace_abn.h
+++ b/include/inplace_abn.h
@@ -132,8 +132,8 @@ struct ActivationFn<scalar_t, Activation::ELU> {
       y = y_act;
       dy = dy_act;
     } else {
-      y = ::log1p(static_cast<scalar_t>(y_act / activation_param));
       dy = static_cast<scalar_t>(dy_act * (y_act + activation_param));
+      y = ::log1p(static_cast<scalar_t>(y_act / activation_param));
     }
   }
 };
diff --git a/inplace_abn/functions.py b/inplace_abn/functions.py
index 7779bb9..59e2ac5 100644
--- a/inplace_abn/functions.py
+++ b/inplace_abn/functions.py
@@ -115,9 +115,11 @@ def forward(
     @once_differentiable
     def backward(ctx, dy_act):
         y_act, var, count, weight, bias = ctx.saved_tensors
-
-        # Call backward_reduce if we need to compute at least one of the gradients
         if any(ctx.needs_input_grad):
+            # Remove memory overlap so dy_act can be overwritten in place
+            dy_act = dy_act.contiguous()
+
+            # This overwrites y_act with xhat and dy_act with dy
             xhat, dy, sum_dy_local, sum_xhat_dy_local = _backend.backward_reduce(
                 y_act,
                 dy_act,
diff --git a/src/inplace_abn.cpp b/src/inplace_abn.cpp
index 5cfb656..60e0458 100644
--- a/src/inplace_abn.cpp
+++ b/src/inplace_abn.cpp
@@ -224,7 +224,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       "iABN forward pass. This is an in-place operation w.r.t. x");
 
   // Backward methods
-  m.def("backward_reduce", &backward_reduce, "First step of the backward pass");
+  m.def("backward_reduce", &backward_reduce, "First step of the backward pass. This is an in-place operation w.r.t. y_act and dy_act");
   m.def(
       "backward_train",
       &backward_train,
diff --git a/src/inplace_abn_cpu.cpp b/src/inplace_abn_cpu.cpp
index 9240328..9d9d1a5 100644
--- a/src/inplace_abn_cpu.cpp
+++ b/src/inplace_abn_cpu.cpp
@@ -30,8 +30,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_impl(
     float eps,
     float activation_param) {
   // Initialize output tensors
-  auto xhat_ = at::empty_like(y_act_);
-  auto dy_ = at::empty_like(y_act_);
+  auto& xhat_ = y_act_; // reuse y_act_'s storage
+  auto& dy_ = dy_act_;  // reuse dy_act_'s storage
   auto sum_dy_ = at::zeros({y_act_.size(1)}, y_act_.options());
   auto sum_xhat_dy_ = at::zeros({y_act_.size(1)}, y_act_.options());
 
@@ -119,13 +119,14 @@ void forward_cpu(
 
   // Apply normalization
   auto abs_weight = weight.has_value()
-      ? weight.value().abs() + eps
+      ? weight.value().abs().add_(eps)
       : at::ones({mean.size(0)}, mean.options());
-  auto inv_std = 1 / at::sqrt(var + eps);
+  auto inv_std = var.add(eps).sqrt_().reciprocal_();
   auto scale = weight.has_value() ? abs_weight * inv_std : inv_std;
-  auto shift = weight.has_value() ? bias.value() - mean * abs_weight * inv_std
-      : -mean * inv_std;
+  auto shift = weight.has_value() ? (-mean).mul_(abs_weight).mul_(inv_std).add_(bias.value())
+      : (-mean).mul_(inv_std);
+  inv_std = at::Tensor(); // free memory
 
   x.mul_(normalize_shape(scale)).add_(normalize_shape(shift));
 
@@ -187,8 +188,8 @@ void backward_cpu(
       normalize_shape(sum_xhat_dy / count.to(sum_xhat_dy.options()));
 
   auto mult = weight.has_value()
-      ? (weight.value().abs() + eps) / (var + eps).sqrt()
-      : 1 / (var + eps).sqrt();
+      ? (weight.value().abs().add_(eps)).div_(var.add(eps).sqrt_())
+      : var.add(eps).sqrt_().reciprocal_();
 
   // dy = (dy - mean_dy - xhat * mean_xhat_dy) * mult
   dy.sub_(mean_dy).sub_(xhat * mean_xhat_dy).mul_(normalize_shape(mult));
diff --git a/src/inplace_abn_cuda.cu b/src/inplace_abn_cuda.cu
index 1c9a4fd..43808fa 100644
--- a/src/inplace_abn_cuda.cu
+++ b/src/inplace_abn_cuda.cu
@@ -143,8 +143,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_templ
   }
 
   // Initialize output tensors
-  auto xhat = at::empty_like(y_act);
-  auto dy = at::empty_like(y_act);
+  auto& xhat = y_act;  // reuse y_act's storage
+  auto& dy = dy_act;   // reuse dy_act's storage
   auto sum_dy = at::empty({chn}, acc_options);
   auto sum_xhat_dy = at::empty({chn}, acc_options);
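
Note on the `dy_act = dy_act.contiguous()` line added in `inplace_abn/functions.py`: the incoming gradient can arrive with overlapping memory (for example a zero-stride tensor produced by a broadcasting op), and PyTorch refuses to run in-place kernels on such storage. The snippet below is a minimal, self-contained sketch in plain PyTorch, not part of this patch, illustrating the failure mode and why the contiguous copy makes it safe for `backward_reduce` to overwrite `dy_act` in place:

```python
import torch

# A gradient coming out of a broadcasted op can alias memory across rows:
# expand() yields stride (0, 1), so all 8 rows point at the same 4 values.
dy = torch.ones(1, 4).expand(8, 4)
print(dy.stride())  # (0, 1)

try:
    dy.mul_(0.5)  # in-place write would hit each stored element 8 times
except RuntimeError as err:
    print("in-place op rejected:", err)

# contiguous() materializes an independent copy with stride (4, 1),
# which can then be overwritten in place, as backward_reduce does with dy_act.
dy = dy.contiguous()
print(dy.stride())  # (4, 1)
dy.mul_(0.5)  # safe now
```

If `dy_act` already owns contiguous, non-overlapping storage, `.contiguous()` returns it unchanged, so the extra copy is only paid when it is actually needed.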