diff --git a/functorch/csrc/BatchRulesScatterOps.cpp b/functorch/csrc/BatchRulesScatterOps.cpp
index d2fdc2c93..a3482a625 100644
--- a/functorch/csrc/BatchRulesScatterOps.cpp
+++ b/functorch/csrc/BatchRulesScatterOps.cpp
@@ -126,7 +126,7 @@ void index_put__batch_rule(
   auto values_ = moveBatchDimToFront(values, values_bdim);
   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
   std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim, values_bdim);
-  at::index_put_(self_, List<optional<Tensor>>(indices_), values, accumulate);
+  at::index_put_(self_, List<optional<Tensor>>(indices_), values_, accumulate);
 }
 
 // plumbing done since we don't support List<optional<Tensor>> in codegen
@@ -158,6 +158,54 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indices
   return self;
 }
 
+void _index_put_impl__batch_rule(
+    Tensor& self,
+    optional<int64_t> self_bdim,
+    ArrayRef<optional<Tensor>> indices,
+    ArrayRef<optional<int64_t>> indices_bdims,
+    const Tensor& values,
+    optional<int64_t> values_bdim,
+    bool accumulate,
+    bool unsafe) {
+  if (!self_bdim.has_value()) {
+    vmapIncompatibleInplaceError("_index_put_impl_");
+  }
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto values_ = moveBatchDimToFront(values, values_bdim);
+  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
+  std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim, values_bdim);
+  at::_index_put_impl_(self_, List<optional<Tensor>>(indices_), values_, accumulate, unsafe);
+}
+
+// plumbing done since we don't support List<optional<Tensor>> in codegen
+Tensor& _index_put_impl__plumbing(Tensor & self, const List<optional<Tensor>> & indices
+, const Tensor & values, bool accumulate, bool unsafe) {
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+  Tensor self_value;
+  optional<int64_t> self_bdim;
+  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
+  std::vector<optional<Tensor>> indices_value;
+  std::vector<optional<int64_t>> indices_bdims;
+  
for (const auto&& indRef : indices) {
+    optional<Tensor> ind = indRef;
+    optional<Tensor> index;
+    optional<int64_t> index_bdim;
+    if (ind.has_value()) {
+      std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
+    }
+    indices_value.push_back(index);
+    indices_bdims.push_back(index_bdim);
+  }
+  Tensor values_value;
+  optional<int64_t> values_bdim;
+  std::tie(values_value, values_bdim) = unwrapTensorAtLevel(values, cur_level);
+  _index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
+  return self;
+}
+
 namespace {
 
 template <typename Func, typename ...Args>
@@ -496,6 +544,7 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
+  m.impl("_index_put_impl_", _index_put_impl__plumbing);
   m.impl("slice_scatter", slice_scatter_decomp);
   m.impl("select_scatter", select_scatter_decomp);
   m.impl("index_copy", index_copy_decomp);
diff --git a/functorch/csrc/BatchedFallback.cpp b/functorch/csrc/BatchedFallback.cpp
index adf085455..4e83e3f0b 100644
--- a/functorch/csrc/BatchedFallback.cpp
+++ b/functorch/csrc/BatchedFallback.cpp
@@ -67,7 +67,10 @@ static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) {
   return std::any_of(
       schema.arguments().begin(),
       schema.arguments().end(),
-      [] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
+      [] (const Argument& arg) {
+        return arg.type()->isSubtypeOf(ListType::ofTensors()) ||
+            arg.type()->isSubtypeOf(ListType::ofOptionalTensors());
+      });
 }
 
 static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
diff --git a/test/functorch_additional_op_db.py b/test/functorch_additional_op_db.py
index 163c3bb4e..6f4eae1be 100644
--- a/test/functorch_additional_op_db.py
+++ b/test/functorch_additional_op_db.py
@@ -258,3 +258,42 @@ def generator():
     sample_inputs_func=sample_inputs_embedding,
     supports_out=False,
 ))
+
+
+def sample_inputs_getitem(op_info, 
device, dtype, requires_grad, **kwargs):
+    S = 5
+    test_args = [
+        ([1, 2],),
+        (slice(0, 3),),
+        ([slice(0, 3), 1],),
+        ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],),
+        ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],),
+        ([slice(None), slice(None), [0, 3]],),
+        ([slice(None), [0, 3], slice(None)],),
+        ([[0, 3], slice(None), slice(None)],),
+        ([[0, 3], [1, 2], slice(None)],),
+        ([[0, 3], ],),
+        ([[0, 3], slice(None)],),
+        ([[0, 3], Ellipsis],),
+        # index_backward is not CompositeCompliant TODO.
+        # ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],),
+    ]
+
+    return tuple(SampleInput(
+        make_tensor((S, S, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
+        args=args)
+        for args in test_args)
+
+
+# TODO: split PyTorch's __getitem__. The problem is we don't support indexing
+# with masks with vmap.
+additional_op_db.append(
+    OpInfo('__getitem__',
+           variant_test_name='functorch',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_inplace_autograd=False,
+           supports_scripting=False,
+           op=torch.Tensor.__getitem__,
+           assert_jit_shape_analysis=False,  # TODO: support index.Tensor()
+           sample_inputs_func=sample_inputs_getitem,))
diff --git a/test/test_ops.py b/test/test_ops.py
index c4ab7d046..dc0b44e6a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -549,7 +549,7 @@ def vjp_of_vjp(*args_and_cotangents):
     xfail('nn.functional.fractional_max_pool3d'),
     xfail('as_strided'),
     xfail('nn.functional.fractional_max_pool2d'),
-    xfail('__getitem__'),
+    xfail('__getitem__', ''),
    xfail('index_put'),
     xfail('lu_solve'),
 })
@@ -744,7 +744,7 @@ def test_vmapjvpall(self, device, dtype, op):
 @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
 @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
     xfail('view_as_complex'),
-    xfail('__getitem__'),
+    xfail('__getitem__', ''),
     xfail('cholesky'),
     xfail('complex'),
     xfail('copysign'),
@@ -865,7 +865,7 @@ def test():
     # fallback path 
doesn't work
     xfail('H'),
     # All of the following are bugs and need to be fixed
-    xfail('__getitem__'),
+    xfail('__getitem__', ''),
     xfail('clamp', ''),
     xfail('dsplit'),
     xfail('fill_'),
diff --git a/test/test_vmap.py b/test/test_vmap.py
index c0e9ca42c..91cf226b9 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -3219,7 +3219,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('to_sparse'),
     xfail('vdot'),
     xfail('vsplit'),
-    xfail('__getitem__'),
+    xfail('__getitem__', ''),
     xfail('all'),
     xfail('any'),
     xfail('count_nonzero'),