Skip to content

Commit

Permalink
Merge pull request #352 from stanford-futuredata/rerank_fix
Browse files (browse the repository at this point in the history)
Do not do initial retrieval if pids are passed in
  • Loading branch information
bclavie authored Nov 14, 2024
2 parents 387be23 + 8881d98 commit 7067ef5
Showing 1 changed file with 45 additions and 44 deletions.
89 changes: 45 additions & 44 deletions colbert/search/index_storage.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -89,8 +89,8 @@ def rank(self, config, Q, filter_fn=None, pids=None):
if pids is None:
pids, centroid_scores = self.retrieve(config, Q)
else:
pids_, centroid_scores = self.retrieve(config, Q)
pids = torch.tensor(pids, dtype=pids_.dtype, device=pids_.device)
pids = torch.tensor(pids, dtype=torch.int32, device=Q.device)
centroid_scores = None

if filter_fn is not None:
filtered_pids = filter_fn(pids)
Expand Down Expand Up @@ -120,49 +120,50 @@ def score_pids(self, config, Q, pids, centroid_scores):
# TODO: Remove batching?
batch_size = 2 ** 20

if self.use_gpu:
centroid_scores = centroid_scores.cuda()

idx = centroid_scores.max(-1).values >= config.centroid_score_threshold

if self.use_gpu:
approx_scores = []

# Filter docs using pruned centroid scores
for i in range(0, ceil(len(pids) / batch_size)):
pids_ = pids[i * batch_size : (i+1) * batch_size]
codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids_)
idx_ = idx[codes_packed.long()]
pruned_codes_strided = StridedTensor(idx_, codes_lengths, use_gpu=self.use_gpu)
pruned_codes_padded, pruned_codes_mask = pruned_codes_strided.as_padded_tensor()
pruned_codes_lengths = (pruned_codes_padded * pruned_codes_mask).sum(dim=1)
codes_packed_ = codes_packed[idx_]
approx_scores_ = centroid_scores[codes_packed_.long()]
if approx_scores_.shape[0] == 0:
approx_scores.append(torch.zeros((len(pids_),), dtype=approx_scores_.dtype).cuda())
continue
approx_scores_strided = StridedTensor(approx_scores_, pruned_codes_lengths, use_gpu=self.use_gpu)
if centroid_scores is not None:
if self.use_gpu:
centroid_scores = centroid_scores.cuda()

idx = centroid_scores.max(-1).values >= config.centroid_score_threshold

if self.use_gpu:
approx_scores = []

# Filter docs using pruned centroid scores
for i in range(0, ceil(len(pids) / batch_size)):
pids_ = pids[i * batch_size : (i+1) * batch_size]
codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids_)
idx_ = idx[codes_packed.long()]
pruned_codes_strided = StridedTensor(idx_, codes_lengths, use_gpu=self.use_gpu)
pruned_codes_padded, pruned_codes_mask = pruned_codes_strided.as_padded_tensor()
pruned_codes_lengths = (pruned_codes_padded * pruned_codes_mask).sum(dim=1)
codes_packed_ = codes_packed[idx_]
approx_scores_ = centroid_scores[codes_packed_.long()]
if approx_scores_.shape[0] == 0:
approx_scores.append(torch.zeros((len(pids_),), dtype=approx_scores_.dtype).cuda())
continue
approx_scores_strided = StridedTensor(approx_scores_, pruned_codes_lengths, use_gpu=self.use_gpu)
approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
approx_scores_ = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
approx_scores.append(approx_scores_)
approx_scores = torch.cat(approx_scores, dim=0)
assert approx_scores.is_cuda, approx_scores.device
if config.ndocs < len(approx_scores):
pids = pids[torch.topk(approx_scores, k=config.ndocs).indices]

# Filter docs using full centroid scores
codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids)
approx_scores = centroid_scores[codes_packed.long()]
approx_scores_strided = StridedTensor(approx_scores, codes_lengths, use_gpu=self.use_gpu)
approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
approx_scores_ = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
approx_scores.append(approx_scores_)
approx_scores = torch.cat(approx_scores, dim=0)
assert approx_scores.is_cuda, approx_scores.device
if config.ndocs < len(approx_scores):
pids = pids[torch.topk(approx_scores, k=config.ndocs).indices]

# Filter docs using full centroid scores
codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids)
approx_scores = centroid_scores[codes_packed.long()]
approx_scores_strided = StridedTensor(approx_scores, codes_lengths, use_gpu=self.use_gpu)
approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
approx_scores = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
if config.ndocs // 4 < len(approx_scores):
pids = pids[torch.topk(approx_scores, k=(config.ndocs // 4)).indices]
else:
pids = IndexScorer.filter_pids(
pids, centroid_scores, self.embeddings.codes, self.doclens,
self.offsets, idx, config.ndocs
)
approx_scores = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
if config.ndocs // 4 < len(approx_scores):
pids = pids[torch.topk(approx_scores, k=(config.ndocs // 4)).indices]
else:
pids = IndexScorer.filter_pids(
pids, centroid_scores, self.embeddings.codes, self.doclens,
self.offsets, idx, config.ndocs
)

# Rank final list of docs using full approximate embeddings (including residuals)
if self.use_gpu:
Expand Down

0 comments on commit 7067ef5

Please sign in to comment.