Fix getitem (#364)
Fixes #363

This PR:
- adds a batch rule for _index_put_impl_
- fixes the index_put_ batch rule
- adds a new OpInfo so we can actually test this
- fixes the fallback paths to error out on Tensor?[] arguments, since the
fallbacks cannot handle optional tensor lists and would otherwise silently
produce wrong results.
zou3519 authored Dec 22, 2021
1 parent 3ba93d3 commit 1af1ae2
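
As a rough illustration of the class of call this commit is about — a minimal sketch, not the exact reproducer from #363, and assuming the functorch-era `vmap` import:

```python
import torch
from functorch import vmap

# Per-example advanced indexing; under vmap this routes through the
# batching rules patched below (index.Tensor on the read path, the
# index_put_ family on in-place writes).
def getitem(x, idx):
    return x[idx]

def setitem(x, idx, v):
    x = x.clone()
    x[idx] = v  # advanced-index write; exercises the index_put_ family
    return x

xs = torch.randn(4, 5)
idxs = torch.randint(5, (4, 2))
vs = torch.randn(4, 2)

print(vmap(getitem)(xs, idxs).shape)      # torch.Size([4, 2])
print(vmap(setitem)(xs, idxs, vs).shape)  # torch.Size([4, 5])
```

Before this commit, the in-place case either had no batch rule at all (_index_put_impl_) or used the wrong operand (values instead of the batch-adjusted values_) in index_put_'s rule.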
Showing 5 changed files with 97 additions and 6 deletions.
51 changes: 50 additions & 1 deletion functorch/csrc/BatchRulesScatterOps.cpp
@@ -126,7 +126,7 @@ void index_put__batch_rule(
   auto values_ = moveBatchDimToFront(values, values_bdim);
   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
   std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim, values_bdim);
-  at::index_put_(self_, List<optional<Tensor>>(indices_), values, accumulate);
+  at::index_put_(self_, List<optional<Tensor>>(indices_), values_, accumulate);
 }
 
 // plumbing done since we don't support List<optional<Tensor>> in codegen
@@ -158,6 +158,54 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
   return self;
 }
 
+void _index_put_impl__batch_rule(
+    Tensor& self,
+    optional<int64_t> self_bdim,
+    ArrayRef<optional<Tensor>> indices,
+    ArrayRef<optional<int64_t>> indices_bdims,
+    const Tensor& values,
+    optional<int64_t> values_bdim,
+    bool accumulate,
+    bool unsafe) {
+  if (!self_bdim.has_value()) {
+    vmapIncompatibleInplaceError("_index_put_impl_");
+  }
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto values_ = moveBatchDimToFront(values, values_bdim);
+  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
+  std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim, values_bdim);
+  at::_index_put_impl_(self_, List<optional<Tensor>>(indices_), values_, accumulate, unsafe);
+}
+
+// plumbing done since we don't support List<optional<Tensor>> in codegen
+Tensor& _index_put_impl__plumbing(Tensor & self, const List<optional<Tensor>> & indices
+, const Tensor & values, bool accumulate, bool unsafe) {
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+  Tensor self_value;
+  optional<int64_t> self_bdim;
+  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
+  std::vector<optional<Tensor>> indices_value;
+  std::vector<optional<int64_t>> indices_bdims;
+  for (const auto&& indRef : indices) {
+    optional<Tensor> ind = indRef;
+    optional<Tensor> index;
+    optional<int64_t> index_bdim;
+    if (ind.has_value()) {
+      std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
+    }
+    indices_value.push_back(index);
+    indices_bdims.push_back(index_bdim);
+  }
+  Tensor values_value;
+  optional<int64_t> values_bdim;
+  std::tie(values_value, values_bdim) = unwrapTensorAtLevel(values, cur_level);
+  _index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
+  return self;
+}
+
 namespace {
 
 template<typename Func, typename ...Args>
@@ -496,6 +544,7 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
+  m.impl("_index_put_impl_", _index_put_impl__plumbing);
   m.impl("slice_scatter", slice_scatter_decomp);
   m.impl("select_scatter", select_scatter_decomp);
   m.impl("index_copy", index_copy_decomp);
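Why a scatter-op batch rule fixes __getitem__: the backward of advanced indexing is an index_put-style scatter into a zeros-like of the input, so vjp/grad of a getitem under vmap runs straight into the rules above. A hedged sketch of that composition, again assuming the functorch-era API:

```python
import torch
from functorch import grad, vmap

def f(x, idx):
    return x[idx].sum()  # backward scatters ones back into a zeros-like of x

xs = torch.randn(4, 5)
idxs = torch.randint(5, (4, 2))

# Per-example gradients of shape (4, 5); each entry counts how often
# that position was gathered by idx.
print(vmap(grad(f))(xs, idxs))
```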
5 changes: 4 additions & 1 deletion functorch/csrc/BatchedFallback.cpp
@@ -67,7 +67,10 @@ static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) {
   return std::any_of(
       schema.arguments().begin(),
       schema.arguments().end(),
-      [] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
+      [] (const Argument& arg) {
+        return arg.type()->isSubtypeOf(ListType::ofTensors()) ||
+               arg.type()->isSubtypeOf(ListType::ofOptionalTensors());
+      });
 }
 
 static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
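The predicate above previously caught only Tensor[] arguments; ops like index_put_ take Tensor?[] (a list of optional tensors), which the slow fallback mishandles — hence the PR's "error out on Tensor?[]" item. A hedged Python analogue of the same check, assuming torch._C.parse_schema and that str() of a JIT type uses schema notation:

```python
import torch

schema = torch._C.parse_schema(
    "aten::index_put_(Tensor(a!) self, Tensor?[] indices, "
    "Tensor values, bool accumulate=False) -> Tensor(a!)"
)

def has_tensorlist_arg(s):
    # Mirrors areAnyArgumentsTensorList: flag Tensor[] or Tensor?[] arguments.
    return any(str(a.type) in ("Tensor[]", "Tensor?[]") for a in s.arguments)

print(has_tensorlist_arg(schema))  # True -> the batched fallback refuses it
```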
39 changes: 39 additions & 0 deletions test/functorch_additional_op_db.py
@@ -258,3 +258,42 @@ def generator():
            sample_inputs_func=sample_inputs_embedding,
            supports_out=False,
            ))
+
+
+def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
+    S = 5
+    test_args = [
+        ([1, 2],),
+        (slice(0, 3),),
+        ([slice(0, 3), 1],),
+        ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],),
+        ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],),
+        ([slice(None), slice(None), [0, 3]],),
+        ([slice(None), [0, 3], slice(None)],),
+        ([[0, 3], slice(None), slice(None)],),
+        ([[0, 3], [1, 2], slice(None)],),
+        ([[0, 3], ],),
+        ([[0, 3], slice(None)],),
+        ([[0, 3], Ellipsis],),
+        # index_backward is not CompositeCompliant TODO.
+        # ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],),
+    ]
+
+    return tuple(SampleInput(
+        make_tensor((S, S, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
+        args=args)
+        for args in test_args)
+
+
+# TODO: split PyTorch's __getitem__. The problem is we don't support indexing
+# with masks with vmap.
+additional_op_db.append(
+    OpInfo('__getitem__',
+           variant_test_name='functorch',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_inplace_autograd=False,
+           supports_scripting=False,
+           op=torch.Tensor.__getitem__,
+           assert_jit_shape_analysis=False,  # TODO: support index.Tensor()
+           sample_inputs_func=sample_inputs_getitem,))
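
For a concrete sense of what these samples exercise, here is one of the index patterns run by hand under vmap (a sketch; S = 5 as above):

```python
import torch
from functorch import vmap

x = torch.randn(4, 5, 5, 5)  # a batch of the (S, S, S) samples, S=5

# Mirrors the ([slice(None), [0, 3]],) sample: t[:, [0, 3]] per example.
print(vmap(lambda t: t[:, [0, 3]])(x).shape)  # torch.Size([4, 5, 2, 5])
```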
6 changes: 3 additions & 3 deletions test/test_ops.py
@@ -549,7 +549,7 @@ def vjp_of_vjp(*args_and_cotangents):
     xfail('nn.functional.fractional_max_pool3d'),
     xfail('as_strided'),
     xfail('nn.functional.fractional_max_pool2d'),
-    xfail('__getitem__'),
+    xfail('__getitem__', ''),
     xfail('index_put'),
     xfail('lu_solve'),
 })
@@ -744,7 +744,7 @@ def test_vmapjvpall(self, device, dtype, op):
 @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
 @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
     xfail('view_as_complex'),
-    xfail('__getitem__'),
+    xfail('__getitem__', ''),
     xfail('cholesky'),
     xfail('complex'),
     xfail('copysign'),
@@ -865,7 +865,7 @@ def test():
     # fallback path doesn't work
     xfail('H'),
     # All of the following are bugs and need to be fixed
-    xfail('__getitem__'),
+    xfail('__getitem__', ''),
     xfail('clamp', ''),
     xfail('dsplit'),
     xfail('fill_'),
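A note on the xfail edits here and in test_vmap.py below: the second positional argument is the OpInfo's variant_test_name, so narrowing xfail('__getitem__') to xfail('__getitem__', '') keeps the default variant expected-to-fail while letting the new 'functorch' variant added above actually run. A hypothetical sketch of that matching convention (not the actual test-harness code):

```python
# Hypothetical: variant=None matches every variant of the op,
# '' matches only the default (unnamed) variant.
def matches(entry_variant, op_variant):
    return entry_variant is None or entry_variant == op_variant

print(matches(None, 'functorch'))  # True  -> xfail('__getitem__') hits both
print(matches('', 'functorch'))    # False -> xfail('__getitem__', '') spares it
print(matches('', ''))             # True  -> default variant still xfailed
```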
2 changes: 1 addition & 1 deletion test/test_vmap.py
@@ -3219,7 +3219,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('to_sparse'),
     xfail('vdot'),
     xfail('vsplit'),
-    xfail('__getitem__'),
+    xfail('__getitem__', ''),
     xfail('all'),
     xfail('any'),
     xfail('count_nonzero'),
