Commit dd94777 (1 parent: ca1a91b), showing 6 changed files with 876 additions and 2 deletions.
@@ -0,0 +1,91 @@
# PECOS XMR Reranker

This is a reranker for the PECOS XMR model, built on Hugging Face's `transformers` library. It can run in both single-process and distributed mode, and follows the approach of [Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319).

## How to run

### Single process
To run the reranker in single-process mode, use the following command:

```bash
python -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
```

### Distributed mode
To run the reranker in distributed mode, first initialize the distributed configuration:

```bash
accelerate config
```

Then launch the reranker with:

```bash
accelerate launch -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
```
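
If you prefer to skip the interactive `accelerate config` step, Accelerate's standard launcher flags can also be passed directly. An illustrative invocation for a single node with 4 GPUs (the GPU count is an example; adjust to your hardware):

```bash
accelerate launch --multi_gpu --num_processes 4 -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
```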

## Configuration file
Here is an example of the configuration file:

```json
{
    "train_params": {
        "__meta__": {
            "class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams"
        },
        "target_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/target",
        "input_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/input",
        "label_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/label",
        "training_args": {
            "__meta__": {
                "class_fullname": "pecos.xmr.reranker.trainer###RankLlamaTrainer.TrainingArgs"
            },
            "learning_rate": 1e-4,
            "output_dir": "./ds_model",
            "per_device_train_batch_size": 8,
            "gradient_accumulation_steps": 8,
            "max_steps": -1,
            "logging_strategy": "steps",
            "logging_first_step": false,
            "logging_steps": 10,
            "save_strategy": "steps",
            "save_steps": 50,
            "save_total_limit": 5,
            "seed": 42,
            "data_seed": 42,
            "bf16": true,
            "dataloader_num_workers": 2,
            "dataloader_prefetch_factor": 10,
            "gradient_checkpointing": true,
            "train_group_size": 16
        }
    },
    "model_params": {
        "__meta__": {
            "class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams"
        },
        "encoder_args": {
            "__meta__": {
                "class_fullname": "pecos.xmr.reranker.model###CrossEncoder.Config"
            },
            "model_shortcut": "meta-llama/Llama-2-7b-hf",
            "model_init_kwargs": {},
            "model_modifier": {
                "modifier_type": "peft",
                "config_type": "LoraConfig",
                "config": {
                    "r": 8,
                    "lora_alpha": 64,
                    "target_modules": ["q_proj", "v_proj"],
                    "modules_to_save": ["score", "classifier"],
                    "lora_dropout": 0.1
                }
            }
        },
        "positive_passage_no_shuffle": false,
        "negative_passage_no_shuffle": false,
        "rerank_max_len": 196,
        "query_prefix": "query: ",
        "passage_prefix": "document: ",
        "append_eos_token": false,
        "pad_to_multiple_of": 16
    }
}
```
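
Before launching a long run, it can help to sanity-check the config. A minimal sketch, using only the standard library and the keys from the example above (the file name `config.json` is a placeholder):

```python
import json
import os

# Load the training configuration (path is an example placeholder).
with open("config.json") as f:
    cfg = json.load(f)

train_params = cfg["train_params"]

# The three data folders should exist and hold the partitioned parquet data.
for key in ("target_data_folder", "input_data_folder", "label_data_folder"):
    folder = train_params[key]
    if not os.path.isdir(folder):
        raise FileNotFoundError(f"missing data folder for {key!r}: {folder}")

# Spot-check a few hyperparameters before kicking off training.
print("learning rate:", train_params["training_args"]["learning_rate"])
print("base model:", cfg["model_params"]["encoder_args"]["model_shortcut"])
```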
@@ -0,0 +1,141 @@
import os
import random
from collections import OrderedDict
from typing import List, Tuple, Callable

import numpy as np
import pyarrow.parquet as pq
from datasets import load_dataset

import pecos


class RankingDataUtils(pecos.BaseClass):
    """
    Utility class for handling data-related tasks
    """

    @classmethod
    def remap_ordereddict(cls, od: OrderedDict, keymap_fn: Callable):
        """
        Function to remap the keys of an ordered dictionary while preserving order
        Args:
            od: The ordered dictionary to remap
            keymap_fn: The function to map the keys
        """
        new_od = OrderedDict()
        for k, v in od.items():
            new_od[keymap_fn(k)] = v
        return new_od

    @classmethod
    def _format_sample(
        cls,
        inp_text: str,
        lbl_text: str,
        lbl_title: str,
        inp_prefix: str = "...",
        passage_prefix: str = "...",
    ) -> str:
        """
        Function to convert the text fields into a formatted string
        that the model understands.
        """
        # normalize hyphenated titles before concatenating the fields
        lbl_title = lbl_title.replace("-", " ").strip()
        return f"{inp_prefix} {inp_text} {passage_prefix} {lbl_title} {lbl_text}".strip()

    @classmethod
    def _create_sample(
        cls,
        inp_id: int,
        ret_idxs: List[int],
        scores: List[float],
        table_stores,
        train_group_size: int,
        inp_prefix: str,
        passage_prefix: str,
    ) -> Tuple[List[str], List[float]]:
        """
        Function to create a sample for training.
        Args:
            inp_id: The input id
            ret_idxs: The retrieved indices
            scores: Scores for the retrieved indices
            table_stores: Dictionary of table stores for input and label data
            train_group_size: The number of passages used to train for each query
            inp_prefix: The input prefix
            passage_prefix: The passage prefix
        Returns: A tuple of formatted samples and scores
        """
        qid = inp_id
        pidxs = ret_idxs

        input_store = table_stores["input"]
        label_store = table_stores["label"]

        # get the values of the query
        query = input_store[qid]["keywords"]
        mean_score = np.mean(scores)

        # split retrieved items into positives (above the mean score) and negatives
        pos_idxs = [(x, pid) for x, pid in zip(scores, pidxs) if x > mean_score]
        neg_idxs = [(x, pid) for x, pid in zip(scores, pidxs) if x <= mean_score]
        random.shuffle(pos_idxs)
        random.shuffle(neg_idxs)

        # aim for half the group to be positives, then fill the rest with negatives
        num_positives = train_group_size // 2

        all_selections = pos_idxs[:num_positives]
        num_positives = len(all_selections)
        num_negatives = train_group_size - num_positives
        all_selections.extend(neg_idxs[:num_negatives])

        # if there are not enough candidates, pad the group by resampling negatives
        if len(all_selections) < train_group_size:
            all_selections.extend(
                random.choices(neg_idxs, k=train_group_size - len(all_selections))
            )

        all_scores = [s for s, _ in all_selections]
        all_pids = [pid for _, pid in all_selections]

        # get the values for the retrieved items
        ret_info = [label_store[i] for i in all_pids]

        formatted_pairs = []
        for info in ret_info:
            formatted_pairs.append(
                cls._format_sample(
                    query, info["contents"], info["title"], inp_prefix, passage_prefix
                )
            )
        return formatted_pairs, all_scores

    @classmethod
    def get_parquet_rows(cls, folder_path: str) -> int:
        """
        Returns the total row count of the parquet files in a folder
        by reading the file metadata
        """
        file_list = os.listdir(folder_path)
        file_list = [os.path.join(folder_path, x) for x in file_list]
        cumulative_rowcount = sum(pq.read_metadata(fp).num_rows for fp in file_list)

        return cumulative_rowcount

    @classmethod
    def get_sorted_data_files(cls, filenames: List[str], idx_colname: str) -> List[str]:
        """
        Returns the list of files sorted by the id in the first row of each file
        """
        # Load the datasets in streaming format and read the first id
        fn_ordered = []  # this contains tuples of (idx, filename)
        for fn in filenames:
            tmp_ds = load_dataset("parquet", data_files=fn, streaming=True, split="train")
            row = next(iter(tmp_ds.take(1)))
            fn_ordered.append((row[idx_colname], fn))
            del tmp_ds
        fn_ordered = sorted(fn_ordered, key=lambda x: x[0])

        return [x[1] for x in fn_ordered]
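
To make the sampling behavior of `_create_sample` concrete, here is a toy usage sketch. The ids, texts, and scores are invented for illustration, and `table_stores` is a plain dict standing in for whatever stores the pipeline actually provides:

```python
import random

# RankingDataUtils is the class defined above.
# Toy stores: "input" maps query id -> fields, "label" maps passage id -> fields.
table_stores = {
    "input": {0: {"keywords": "what is extreme multi-label ranking"}},
    "label": {
        10: {"title": "XMR", "contents": "Extreme multi-label ranking matches inputs to labels."},
        11: {"title": "PECOS", "contents": "PECOS is a framework for XMR problems."},
        12: {"title": "Cooking", "contents": "How to boil pasta."},
        13: {"title": "Travel", "contents": "Ten places to visit."},
    },
}

random.seed(0)  # candidate order is shuffled, so fix the seed for a stable demo

pairs, scores = RankingDataUtils._create_sample(
    inp_id=0,
    ret_idxs=[10, 11, 12, 13],
    scores=[0.9, 0.7, 0.2, 0.1],  # mean is 0.475, so 0.9 and 0.7 become positives
    table_stores=table_stores,
    train_group_size=4,
    inp_prefix="query:",
    passage_prefix="document:",
)

# pairs: four strings like "query: what is ... document: XMR Extreme multi-label ..."
# scores: the matching retrieval scores, positives first, then negatives.
print(pairs[0], scores)
```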