
Commit

[skip ci] WIP on index_fill batch rule
vfdev-5 committed Dec 27, 2021
1 parent 1af1ae2 commit e232281
Showing 3 changed files with 107 additions and 5 deletions.
106 changes: 106 additions & 0 deletions functorch/csrc/BatchRulesScatterOps.cpp
@@ -541,6 +541,110 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
return std::make_tuple(at::stack(results), 0);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, optional<int64_t> index_bdim,
    const Scalar& value) {

  if (!index_bdim) {
    // index is not batched: operate on self directly, taking care of
    // 0-dim (scalar) self tensors
    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
    auto self_ = moveBatchDimToFront(self, self_bdim);
    if (self_logical_rank == 0) {
      self_ = self_.unsqueeze(-1);
    }
    dim = maybe_wrap_dim(dim, self_logical_rank);

    optional<int64_t> out_bdim = nullopt;
    if (self_bdim) {
      const auto batch_size = self.size(*self_bdim);
      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
      dim = dim + 1;  // shift past the leading batch dimension
      out_bdim = 0;
    }

    auto result = self_.index_fill(dim, index, value);
    if (self_logical_rank == 0) {
      result = result.squeeze(-1);
    }
    return std::make_tuple(result, out_bdim);
  }

  // Same as for index_add: index is batched, and a for-loop plus stack is
  // the best we can do for now. We really want a generalized index_fill
  // kernel in PyTorch.
  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
  std::vector<Tensor> results;
  results.reserve(batch_size);
  for (const auto i : c10::irange(0, batch_size)) {
    const auto& self_slice = self_bdim.has_value() ?
      self.select(*self_bdim, i) : self;
    const auto& index_slice = index_bdim.has_value() ?
      index.select(*index_bdim, i) : index;
    results.push_back(at::index_fill(self_slice, dim, index_slice, value));
  }
  return std::make_tuple(at::stack(results), 0);
}
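
For reference, a minimal sketch of the two paths this batch rule covers, assuming functorch's vmap API (the tensor names here are illustrative, not from this commit): an unbatched index takes the direct index_fill path, while a per-sample index falls through to the loop-and-stack fallback.

import torch
from functorch import vmap

x = torch.zeros(3, 5)                  # batch of 3 vectors
idx = torch.tensor([0, 2])             # unbatched index -> direct path
out = vmap(lambda t: t.index_fill(0, idx, 1.0))(x)

bidx = torch.tensor([[0], [1], [2]])   # one index per sample -> loop-and-stack path
out2 = vmap(lambda t, i: t.index_fill(0, i, 1.0))(x, bidx)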

std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, optional<int64_t> index_bdim,
    const Tensor& value, optional<int64_t> value_bdim) {

  if (!index_bdim && !value_bdim) {
    // neither index nor value is batched: operate on self directly, taking
    // care of 0-dim (scalar) self tensors
    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
    auto self_ = moveBatchDimToFront(self, self_bdim);
    if (self_logical_rank == 0) {
      self_ = self_.unsqueeze(-1);
    }
    dim = maybe_wrap_dim(dim, self_logical_rank);

    optional<int64_t> out_bdim = nullopt;
    if (self_bdim) {
      const auto batch_size = self.size(*self_bdim);
      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
      dim = dim + 1;  // shift past the leading batch dimension
      out_bdim = 0;
    }

    auto result = self_.index_fill(dim, index, value);
    if (self_logical_rank == 0) {
      result = result.squeeze(-1);
    }
    return std::make_tuple(result, out_bdim);
  }

  // Same as for index_add: index (or value) is batched, and a for-loop plus
  // stack is the best we can do for now. We really want a generalized
  // index_fill kernel in PyTorch.
  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
  std::vector<Tensor> results;
  results.reserve(batch_size);
  for (const auto i : c10::irange(0, batch_size)) {
    const auto& self_slice = self_bdim.has_value() ?
      self.select(*self_bdim, i) : self;
    const auto& index_slice = index_bdim.has_value() ?
      index.select(*index_bdim, i) : index;
    const auto& value_slice = value_bdim.has_value() ?
      value.select(*value_bdim, i) : value;
    results.push_back(at::index_fill(self_slice, dim, index_slice, value_slice));
  }
  return std::make_tuple(at::stack(results), 0);
}
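
A corresponding sketch for the Tensor-value overload (again assuming functorch's vmap; index_fill requires a 0-dim value tensor, so batching over value means one scalar tensor per sample and takes the loop-and-stack path).

import torch
from functorch import vmap

x = torch.zeros(3, 5)
idx = torch.tensor([1, 3])
vals = torch.tensor([1.0, 2.0, 3.0])   # vmap slices this into 0-dim values

out = vmap(lambda t, v: t.index_fill(0, idx, v))(x, vals)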

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("index.Tensor", index_plumbing);
m.impl("index_put_", index_put__plumbing);
@@ -550,6 +654,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("index_copy", index_copy_decomp);
m.impl("index_select", index_select_decomp);
VMAP_SUPPORT("index_add", index_add_batch_rule);
VMAP_SUPPORT("index_fill.int_Scalar", index_fill_int_scalar_batch_rule);
VMAP_SUPPORT("index_fill.int_Tensor", index_fill_int_tensor_batch_rule);
VMAP_SUPPORT("diagonal_scatter", diagonal_scatter_batch_rule);
VMAP_SUPPORT("gather", gather_batch_rule);
VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);
5 changes: 1 addition & 4 deletions test/test_ops.py
@@ -513,7 +513,6 @@ def vjp_of_vjp(*args_and_cotangents):
xfail('fmax'),
xfail('fmin'),
xfail('index_copy'),
- xfail('index_fill'),
xfail('linalg.det', ''),
xfail('linalg.eigh'),
xfail('linalg.householder_product'),
@@ -595,7 +594,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('block_diag'), # TODO: We expect this to fail in core, but it doesn't
xfail('index_copy'),
xfail('index_put'),
- xfail('index_fill'),
xfail('masked_fill'),
xfail('masked_scatter'),
@@ -701,7 +699,6 @@ def test_vmapjvp(self, device, dtype, op):
xfail('max', 'binary'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('min', 'binary'),
- xfail('index_fill'),
xfail('index_put'),
xfail('std_mean'),
xfail('double', 'channels_last'),
@@ -760,7 +757,7 @@ def test_vmapjvpall(self, device, dtype, op):
xfail('fmax'),
xfail('fmin'),
xfail('index_copy'),
- xfail('index_fill'),
+ xfail('index_fill'), # RuntimeError: aten::_unique hit the vmap fallback which is currently disabled
xfail('linalg.cholesky'),
xfail('linalg.cholesky_ex'),
xfail('linalg.det'),
1 change: 0 additions & 1 deletion test/test_vmap.py
@@ -3181,7 +3181,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('gradient'),
xfail('histogram'),
xfail('hsplit'),
- xfail('index_fill'),
xfail('index_put'),
xfail('isin'),
xfail('linalg.cholesky'),
