diff --git a/colbert/indexing/utils.py b/colbert/indexing/utils.py index 92f0954b..b7f80507 100644 --- a/colbert/indexing/utils.py +++ b/colbert/indexing/utils.py @@ -44,6 +44,15 @@ def optimize_ivf(orig_ivf, orig_ivf_lengths, index_path, verbose:int=3): offset += length ivf = torch.cat(unique_pids_per_centroid) ivf_lengths = torch.tensor(ivf_lengths) + + max_stride = ivf_lengths.max().item() + zero = torch.zeros(1, dtype=torch.long, device=ivf_lengths.device) + offsets = torch.cat((zero, torch.cumsum(ivf_lengths, dim=0))) + inner_dims = ivf.size()[1:] + + if offsets[-2] + max_stride > ivf.size(0): + padding = torch.zeros(max_stride, *inner_dims, dtype=ivf.dtype, device=ivf.device) + ivf = torch.cat((ivf, padding)) original_ivf_path = os.path.join(index_path, 'ivf.pt') optimized_ivf_path = os.path.join(index_path, 'ivf.pid.pt') diff --git a/colbert/search/decompress_residuals.cpp b/colbert/search/decompress_residuals.cpp index fbca4f80..f6ef4451 100644 --- a/colbert/search/decompress_residuals.cpp +++ b/colbert/search/decompress_residuals.cpp @@ -156,5 +156,5 @@ torch::Tensor decompress_residuals( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decompress_residuals_cpp", &decompress_residuals, - "Decompress residuals"); + "Decompress residuals", py::call_guard()); } diff --git a/colbert/search/filter_pids.cpp b/colbert/search/filter_pids.cpp index a85e4425..57564c64 100644 --- a/colbert/search/filter_pids.cpp +++ b/colbert/search/filter_pids.cpp @@ -170,5 +170,5 @@ torch::Tensor filter_pids(const torch::Tensor pids, } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("filter_pids_cpp", &filter_pids, "Filter pids"); + m.def("filter_pids_cpp", &filter_pids, "Filter pids", py::call_guard()); }