Skip to content

Commit

Permalink
added support for rankllama
Browse files Browse the repository at this point in the history
  • Loading branch information
aniquetahir committed Aug 6, 2024
1 parent ca1a91b commit dd94777
Show file tree
Hide file tree
Showing 6 changed files with 876 additions and 2 deletions.
91 changes: 91 additions & 0 deletions pecos/xmr/reranker/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# PECOS XMR Reranker

This is a reranker for the PECOS XMR model. It is based on huggingface's transformers library. The reranker can be run in both
single process and distributed mode. It is based on the paper [Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319).

## How to run
### Single process
To run the reranker in single process mode, you can use the following command:

```bash
python -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
```

### Distributed mode
To run the reranker in distributed mode, you can use the following command to initialize the distributed configuration:
```bash
accelerate config
```

Then you can run the reranker using the following command:
```bash
accelerate launch -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
```

## Configuration file
Here is an example of the configuration file:
```json
{
"train_params": {
"__meta__": {
"class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams"
},
"target_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/target",
"input_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/input",
"label_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/label",
"training_args": {
"__meta__": {
"class_fullname": "pecos.xmr.reranker.trainer###RankLlamaTrainer.TrainingArgs"
},
"learning_rate": 1e-4,
"output_dir": "./ds_model",
"per_device_train_batch_size": 8,
"gradient_accumulation_steps": 8,
"max_steps": -1,
"logging_strategy": "steps",
"logging_first_step": false,
"logging_steps": 10,
"save_strategy": "steps",
"save_steps": 50,
"save_total_limit": 5,
"seed": 42,
"data_seed": 42,
"bf16": true,
"dataloader_num_workers": 2,
"dataloader_prefetch_factor": 10,
"gradient_checkpointing": true,
"train_group_size": 16
}
},
"model_params": {
"__meta__": {
"class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams"
},
"encoder_args": {
"__meta__": {
"class_fullname": "pecos.xmr.reranker.model###CrossEncoder.Config"
},
"model_shortcut": "meta-llama/Llama-2-7b-hf",
"model_init_kwargs": {},
"model_modifier": {
"modifier_type": "peft",
"config_type": "LoraConfig" ,
"config": {
"r": 8,
"lora_alpha": 64,
"target_modules": ["q_proj", "v_proj"],
"modules_to_save": ["score", "classifier"],
"lora_dropout": 0.1
}
}
},
"positive_passage_no_shuffle": false,
"negative_passage_no_shuffle": false,
"rerank_max_len": 196,
"query_prefix": "query: ",
"passage_prefix": "document: ",
"append_eos_token": false,
"pad_to_multiple_of": 16
}
}
```
141 changes: 141 additions & 0 deletions pecos/xmr/reranker/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import os
import random
from collections import OrderedDict
from typing import List, Tuple, Callable

import numpy as np
import pyarrow.parquet as pq
from datasets import load_dataset

import pecos


class RankingDataUtils(pecos.BaseClass):
"""
Utility class for handling data related tasks
"""

@classmethod
def remap_ordereddict(cls, od: OrderedDict, keymap_fn: Callable):
"""
Function to remap the keys of an ordered Dictionary
Args:
od: The ordered dictionary to remap
keymap_fn: The function to map the keys
"""
new_od = OrderedDict()
for k, v in od.items():
new_od[keymap_fn(k)] = v
return new_od

@classmethod
def _format_sample(
cls,
inp_text: str,
lbl_text: str,
lbl_title: str,
inp_prefix: str = "...",
passage_prefix: str = "...",
) -> str:
"""
Function to convert the text fields into a formatted string
that the model understands.
"""
lbl_title = lbl_title.replace("-", " ").strip()
return f"{inp_prefix} {inp_text} {passage_prefix} {lbl_title} {lbl_text}".strip()

@classmethod
def _create_sample(
cls,
inp_id: int,
ret_idxs: List[int],
scores: List[float],
table_stores,
train_group_size: int,
inp_prefix: str,
passage_prefix: str,
) -> Tuple[List[str], List[float]]:
"""
Function to create a sample for training.
Args:
inp_id: The input id
ret_idxs: The retrieved indices
scores: Scores for the retrieved indices
table_stores: Dictionary of table stores for input and label data
train_group_size: The number of passages used to train for each query
inp_prefix: The input prefix
passage_prefix: The passage prefix
Returns: A tuple of formatted samples and scores
"""
qid = inp_id
pidxs = ret_idxs

input_store = table_stores["input"]
label_store = table_stores["label"]

# get the values of the query
query = input_store[qid]["keywords"]
mean_score = np.mean(scores)

# get idxs for positive items
pos_idxs = [(x, pid) for x, pid in zip(scores, pidxs) if x > mean_score]
neg_idxs = [(x, pid) for x, pid in zip(scores, pidxs) if x <= mean_score]
random.shuffle(pos_idxs)
random.shuffle(neg_idxs)

num_positives = train_group_size // 2

all_selections = pos_idxs[:num_positives]
num_positives = len(all_selections)
num_negatives = train_group_size - num_positives
all_selections.extend(neg_idxs[:num_negatives])

if len(all_selections) < train_group_size:
all_selections.extend(
random.choices(neg_idxs, k=train_group_size - len(all_selections))
)

all_scores = [s for s, _ in all_selections]
all_pids = [pid for _, pid in all_selections]

# get the values for the retrieved items
ret_info = [label_store[i] for i in all_pids]

formated_pair = []
for info in ret_info:
formated_pair.append(
cls._format_sample(
query, info["contents"], info["title"], inp_prefix, passage_prefix
)
)
return formated_pair, all_scores

@classmethod
def get_parquet_rows(cls, folder_path: str) -> int:
"""
Returns the count of rows in parquet files by reading the
metadata
"""
file_list = os.listdir(folder_path)
file_list = [os.path.join(folder_path, x) for x in file_list]
cumulative_rowcount = sum([pq.read_metadata(fp).num_rows for fp in file_list])

return cumulative_rowcount

@classmethod
def get_sorted_data_files(cls, filenames: List[str], idx_colname) -> List[str]:
"""
Returns the list of files sorted by the id in the first row of each file
"""
# Load the datasets in streaming format and read the first id
fn_ordered = [] # this containes tuples with (idx, filename)
for fn in filenames:
tmp_ds = load_dataset("parquet", data_files=fn, streaming=True, split="train")
row = next(iter(tmp_ds.take(1)))
fn_ordered.append((row[idx_colname], fn))
del tmp_ds
fn_ordered = sorted(fn_ordered, key=lambda x: x[0])

return [x[1] for x in fn_ordered]
Loading

0 comments on commit dd94777

Please sign in to comment.