Generate n^2 not n^3 inputs for batch and instance norm; small batch norm fix (#951)

* refactor batch norm exhaustive inputs

* fix typo in batch rule

* fix expand issue, add without cudnn xfail
Samantha Andow authored Jul 12, 2022
1 parent 4f25800 commit ca3ac11
Showing 3 changed files with 71 additions and 80 deletions.
8 changes: 7 additions & 1 deletion functorch/csrc/BatchRulesNorm.cpp
@@ -75,7 +75,7 @@ batch_norm_batch_rule(
    mean = std::get<1>(result);
    rstd = std::get<2>(result);
  } else {
-     bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_mean_bdim);
+     bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
    auto input_ = moveBatchDimToFront(input, input_bdim);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value());
    input_ = reshape_dim_into(0, /*channels dim*/1, input_);
@@ -86,11 +86,17 @@
      running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
      running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
      running_mean_ = reshape_dim_into(0, 0, *running_mean_);
+       if (training) {
+         running_mean_ = running_mean_->contiguous();
+       }
    }
    if (running_var.defined()) {
      running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
      running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
      running_var_ = reshape_dim_into(0, 0, *running_var_);
+       if (training) {
+         running_var_ = running_var_->contiguous();
+       }
    }

    const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
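
Not part of the diff, but as context for the new contiguous() calls: when the running stats reach this batch rule without their own batch dimension, the ensure_has_bdim-style broadcasting produces an expanded, stride-0 view, and training-mode batch norm then has to update those stats in place. The following plain-PyTorch sketch is my illustration of that reading, assuming it applies here; it is not the C++ batch rule itself.

import torch

# A per-channel running_mean without a batch dimension.
running_mean = torch.zeros(4)

# Broadcasting it to a batch of 3 yields a stride-0 view: all three "copies"
# alias the same storage, and the result is not contiguous.
expanded = running_mean.unsqueeze(0).expand(3, 4)
print(expanded.is_contiguous())        # False
print(expanded.stride())               # (0, 1)

# .contiguous() materializes independent per-batch storage that an in-place
# update (as training-mode batch norm performs on its running stats) can
# safely write to.
materialized = expanded.contiguous()
print(materialized.is_contiguous())    # True
print(materialized.stride())           # (4, 1)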
133 changes: 55 additions & 78 deletions test/common_utils.py
@@ -70,6 +70,34 @@ def get_bdim_choices(num_tensors):
    assert choices[-1] == (None,) * num_tensors
    return tuple(choices[:-1])

+ # NB: This is O(2 ** num_tensors).
+ # num_tensors ranges from 1 to 10, with 2-4 being most common.
+ # Try not to make it any more expensive if you're modifying it.
+ def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=None, *args):
+     choices = []
+     options = (-1, None)

+     # instance norm turns these into unbatched 0 tensors, so we cannot batch the input if either is not specified
+     if running_mean is None or running_var is None:
+         choices.append((None,) + (0,) * (num_tensors - 1))
+         for choice in itertools.product(options, repeat=num_tensors - 1):
+             choices.append((None,) + choice)

+     else:
+         # running_mean and running_var are specified as tensors. Batch norm doesn't work if the input is batched but
+         # running_mean/var are unbatched, so this tests all other cases
+         choices.append((0,) * num_tensors)
+         for choice in itertools.product(options, repeat=num_tensors):
+             input_bdim = choice[0]
+             running_mean_bdim = choice[1]
+             running_var_bdim = choice[2]
+             if input_bdim and (not running_mean_bdim or not running_var_bdim):
+                 continue
+             choices.append(choice)

+     assert choices[-1] == (None,) * num_tensors
+     return tuple(choices[:-1])
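
(Illustration, not part of the commit.) For the common batch-norm case of three tensor arguments with the running stats present, the branch above keeps the fully batched choice plus every (-1, None) combination in which a batched input is paired with batched running stats. A standalone re-statement of that rule, assuming the layout (input_bdim, running_mean_bdim, running_var_bdim), makes the reduced count visible:

import itertools

# Assumed layout: (input_bdim, running_mean_bdim, running_var_bdim).
options = (-1, None)
choices = [(0, 0, 0)]  # everything batched at dim 0
for choice in itertools.product(options, repeat=3):
    input_bdim, running_mean_bdim, running_var_bdim = choice
    # A batched input with an unbatched running stat is skipped, because
    # batch norm updates running_mean/running_var in place.
    if input_bdim and (not running_mean_bdim or not running_var_bdim):
        continue
    choices.append(choice)

assert choices[-1] == (None, None, None)
print(choices[:-1])  # 5 surviving combinations; the all-None case is dropped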


def add_batch_dim(arg, bdim, batch_size=3):
    assert bdim == 0 or bdim == -1
@@ -93,12 +121,7 @@ def construct_in_dims(bdim_choice_for_tensors, is_tensors):
        result.append(next(bdim))
    return tuple(result)


- def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2, *, for_batch_norm=False):
-     if for_batch_norm:
-         # TODO: delete this path
-         return get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
+ def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2):
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
    bdim_choices = get_bdim_choices(sum(is_tensors))
@@ -120,87 +143,41 @@ def get_batched_arg(arg, bdim):
        yield batched_args, in_dims, kwarg_values


- def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
-     for_batch_norm = True
-     assert bdims == (0,) or bdims == (0, -1)

-     def add_batch_dim(arg, bdim, batch_size=3):
-         assert bdim == 0 or bdim == -1
-         if isinstance(arg, torch.Tensor):
-             if bdim == 0:
-                 shape = [1] * len(arg.shape)
-                 shape.insert(bdim, batch_size)
-                 return (arg.repeat(shape), bdim)
-             if bdim == -1:
-                 arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
-                 return (arg, bdim)
-             assert False
-         else:
-             return (arg, None)
-     for bdim in bdims:
-         batch_choices = []

-         def add_batch_choices(a):
-             if isinstance(a, torch.Tensor):
-                 batched_val = add_batch_dim(a, bdim, batch_size)
-                 batch_choices.append((batched_val, (a, None)))
-             else:
-                 batch_choices.append(((a, None),))

-         flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
-         if for_batch_norm:
-             # Batch norm is unique because the running_mean and running_var are updated in place.
-             # Therefore, they cannot be unbatched if the input is batched. The case where both are
-             # unbatched is added at the end
-             if len(flat_args) >= 3:
-                 add_batch_choices(flat_args[0])  # input can be batched or unbatched
-                 batch_choices.append((add_batch_dim(flat_args[1], bdim, batch_size),))  # running_mean must be batched
-                 batch_choices.append((add_batch_dim(flat_args[2], bdim, batch_size),))  # running_var must be batched
-                 orig_flat_args = flat_args
-                 flat_args = orig_flat_args[3:]
-             else:
-                 # TODO: None defaults in instance norm create empty tensors that are written to and mean that we must
-                 # have unbatched inputs. None in the running mean/running var shouldn't make a tensor
-                 batch_choices.append(((flat_args[0], None),))  # input must be unbatched
-                 if len(flat_args) == 2:
-                     batch_choices.append((add_batch_dim(flat_args[1], bdim, batch_size),))
-                 orig_flat_args = flat_args
-                 flat_args = []

-         for arg in flat_args:
-             add_batch_choices(arg)

-         for batched_values in itertools.product(*batch_choices):
-             batched_args, in_dims = zip(*batched_values)

-             if all([i is None for i in in_dims]):
-                 continue

-             yield pytree.tree_unflatten(batched_args, arg_spec), pytree.tree_unflatten(in_dims, arg_spec), kwarg_values

-         if for_batch_norm and len(orig_flat_args) >= 2:
-             # Adds the case where input, running_mean, and running_var are all unbatched
-             batch_choices[0] = ((orig_flat_args[0], None),)
-             batch_choices[1] = ((orig_flat_args[1], None),)
-             if len(orig_flat_args) >= 3:
-                 batch_choices[2] = ((orig_flat_args[2], None),)
-             for batched_values in itertools.product(*batch_choices):
-                 batched_args, in_dims = zip(*batched_values)

-                 if all([i is None for i in in_dims]):
-                     continue

-                 batched_args_tuple = pytree.tree_unflatten(batched_args, arg_spec)
-                 in_dims_tuple = pytree.tree_unflatten(in_dims, arg_spec)
-                 yield batched_args_tuple, in_dims_tuple, kwarg_values
+ def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=2):
+     flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
+     is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
+     num_tensors = sum(is_tensors)
+     if num_tensors == 1:  # if there's only an input, can't batch it since running_mean/var will be seen as unbatched tensors
+         return
+     bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values)

+     @memoize
+     def get_batched_arg(arg, bdim):
+         assert isinstance(arg, torch.Tensor)
+         assert bdim is not None
+         result, _ = add_batch_dim(arg, bdim, batch_size)
+         return result

+     for bdim_choice in bdim_choices:
+         flat_in_dims = construct_in_dims(bdim_choice, is_tensors)

+         flat_batched_args = tuple(arg if in_dim is None else get_batched_arg(arg, in_dim)
+                                   for arg, in_dim in zip(flat_args, flat_in_dims))
+         batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
+         in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
+         yield batched_args, in_dims, kwarg_values


def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True):
    out_dim = 0
    batch_size = 2
    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
-     for_batch_norm = opinfo is not None and opinfo.name in batch_norm_fns
-     generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, for_batch_norm=for_batch_norm)
+     if opinfo is not None and opinfo.name in batch_norm_fns:
+         generator = get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
+     else:
+         generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size)

    for batched_args, in_dims, kwarg_values in generator:
        if compute_loop_out:
            loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
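
Not part of the diff: the reason batch norm and instance norm get their own generator at all is that training-mode batch norm mutates running_mean and running_var in place, so a batched input paired with unbatched running stats is not a meaningful test case. A minimal plain-PyTorch illustration of that in-place update (my example, not functorch code):

import torch
import torch.nn.functional as F

x = torch.randn(8, 4)
running_mean = torch.zeros(4)
running_var = torch.ones(4)

F.batch_norm(x, running_mean, running_var, training=True, momentum=0.1)
print(running_mean)  # updated in place: no longer all zeros
print(running_var)   # updated in place: no longer all ones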
10 changes: 9 additions & 1 deletion test/test_ops.py
@@ -19,6 +19,7 @@
from common_utils import (
    get_fallback_and_vmap_exhaustive,
    get_exhaustive_batched_inputs,
+     get_exhaustive_batched_inputs_batch_norm,
    xfail,
    skip,
    skipOps,
@@ -663,6 +664,10 @@ def test_vmapvjp(self, device, dtype, op):

    xfail('put'),  # calls put_ during vmap with only vmaps over other, not self
    xfail('nn.functional.prelu'),  # Call Tensor.as_strided

+     # erroring because running_mean and running_var aren't differentiable
+     xfail('nn.functional.batch_norm'),
+     xfail('nn.functional.batch_norm', 'without_cudnn'),
}

@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -964,7 +969,10 @@ def test_vjpvmap(self, device, dtype, op):
        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
-             generator = get_exhaustive_batched_inputs(args, kwargs, for_batch_norm=is_batch_norm)
+             if is_batch_norm:
+                 generator = get_exhaustive_batched_inputs_batch_norm(args, kwargs)
+             else:
+                 generator = get_exhaustive_batched_inputs(args, kwargs)

            for batched_args, in_dims, kwargs in generator:
                vmapped_op = vmap(op, in_dims)
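
Not part of the diff: a short plain-PyTorch note on the two new xfails. F.batch_norm treats running_mean and running_var as buffers, so autograd produces gradients for the input and affine parameters but never for the running stats, which is what the vjp-based tests trip over. This is my illustration; it does not reproduce the exact failure inside functorch's vjp machinery.

import torch
import torch.nn.functional as F

x = torch.randn(8, 4, requires_grad=True)
weight = torch.ones(4, requires_grad=True)
running_mean = torch.zeros(4)   # buffers: requires_grad is False
running_var = torch.ones(4)

out = F.batch_norm(x, running_mean, running_var, weight, training=True)
out.sum().backward()

print(x.grad.shape, weight.grad.shape)  # gradients exist for input and weight
print(running_mean.requires_grad, running_var.requires_grad)  # False False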
