-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ca1a91b
commit 725f043
Showing
7 changed files
with
1,300 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
# PECOS XMR Reranker | ||
|
||
This is a reranker for the PECOS XMR model. It is based on huggingface's transformers library. The reranker can be run in both | ||
single process and distributed mode. It is based on the paper [Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319). | ||
|
||
## How to run | ||
### Single process | ||
To run the reranker in single process mode, you can use the following command: | ||
|
||
```bash | ||
python -m pecos.xmr.reranker.train --config_json_path <path_to_config_file> | ||
``` | ||
|
||
### Distributed mode | ||
To run the reranker in distributed mode, you can use the following command to initialize the distributed configuration: | ||
```bash | ||
accelerate config | ||
``` | ||
|
||
Then you can run the reranker using the following command: | ||
```bash | ||
accelerate launch -m pecos.xmr.reranker.train --config_json_path <path_to_config_file> | ||
``` | ||
|
||
### Predictions | ||
To run the reranker in prediction mode, you can use the following command: | ||
```bash | ||
python -m pecos.xmr.reranker.predict --config_json_path <path_to_config_file> | ||
``` | ||
|
||
## Configuration file | ||
|
||
### Training | ||
Here is an example of the configuration file for training: | ||
```json | ||
{ | ||
"train_params": { | ||
"__meta__": { | ||
"class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams" | ||
}, | ||
"target_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/target", | ||
"input_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/input", | ||
"label_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/label", | ||
"training_args": { | ||
"__meta__": { | ||
"class_fullname": "pecos.xmr.reranker.trainer###RankLlamaTrainer.TrainingArgs" | ||
}, | ||
"learning_rate": 1e-4, | ||
"output_dir": "./ds_model", | ||
"per_device_train_batch_size": 8, | ||
"gradient_accumulation_steps": 8, | ||
"max_steps": -1, | ||
"logging_strategy": "steps", | ||
"logging_first_step": false, | ||
"logging_steps": 10, | ||
"save_strategy": "steps", | ||
"save_steps": 50, | ||
"save_total_limit": 5, | ||
"seed": 42, | ||
"data_seed": 42, | ||
"bf16": true, | ||
"dataloader_num_workers": 2, | ||
"dataloader_prefetch_factor": 10, | ||
"gradient_checkpointing": true, | ||
"train_group_size": 16 | ||
} | ||
}, | ||
"model_params": { | ||
"__meta__": { | ||
"class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams" | ||
}, | ||
"encoder_args": { | ||
"__meta__": { | ||
"class_fullname": "pecos.xmr.reranker.model###CrossEncoder.Config" | ||
}, | ||
"model_shortcut": "meta-llama/Llama-2-7b-hf", | ||
"model_init_kwargs": {}, | ||
"model_modifier": { | ||
"modifier_type": "peft", | ||
"config_type": "LoraConfig" , | ||
"config": { | ||
"r": 8, | ||
"lora_alpha": 64, | ||
"target_modules": ["q_proj", "v_proj"], | ||
"modules_to_save": ["score", "classifier"], | ||
"lora_dropout": 0.1 | ||
} | ||
} | ||
}, | ||
"positive_passage_no_shuffle": false, | ||
"negative_passage_no_shuffle": false, | ||
"rerank_max_len": 196, | ||
"query_prefix": "query: ", | ||
"passage_prefix": "document: ", | ||
"inp_id_col": "inp_id", | ||
"lbl_idxs_col": "ret_idxs", | ||
"score_col": "rel", | ||
"keyword_col_name": "keywords", | ||
"content_col_names": ["title", "contents"], | ||
"append_eos_token": false, | ||
"pad_to_multiple_of": 16 | ||
} | ||
} | ||
``` | ||
|
||
### Prediction | ||
Following is the example of the configuration file for prediction: | ||
```json | ||
{ | ||
"model_name_or_path": "/tmp/pecosdev/ds_model", | ||
"target_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/target", | ||
"input_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/input", | ||
"label_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/label", | ||
"output_dir": "/tmp/xmrout", | ||
"per_device_eval_batch_size": 512, | ||
"dataloader_num_workers": 1, | ||
"dataloader_prefetch_factor": 10, | ||
"rerank_max_len": 196, | ||
"query_prefix": "query: ", | ||
"passage_prefix": "document: ", | ||
"inp_id_col": "inp_id", | ||
"lbl_id_col": "lbl_id", | ||
"keyword_col_name": "keywords", | ||
"content_col_names": ["title", "contents"], | ||
"append_eos_token": false, | ||
"pad_to_multiple_of": 8, | ||
"device": "cuda", | ||
"model_init_kwargs": { | ||
"device_map": "auto" | ||
} | ||
} | ||
``` | ||
|
||
## Data Schema | ||
The column names for the data schema are configurable through the json configuration file. Following | ||
are the various schemas that are supported by the reranker: | ||
|
||
(1) Learning Target Schema | ||
``` | ||
# +-----------------+---------------+-----------------------+ | ||
# | Column Name | Data Type | Description | | ||
# +-----------------+---------------+-----------------------+ | ||
# | inp_id | int32 | input id | | ||
# | lbl_id | array<int32> | an array of label_id | | ||
# | score | array<float> | an array of rel_score | | ||
# +-----------------+---------------+-----------------------+ | ||
``` | ||
|
||
(2) Input Feature Store Schema | ||
``` | ||
# +-----------------+---------------+-----------------------+ | ||
# | Column Name | Data Type | Description | | ||
# +-----------------+---------------+-----------------------+ | ||
# | inp_id | int32 | input id | | ||
# | keywords | string | keyword string | | ||
# +-----------------+---------------+-----------------------+ | ||
``` | ||
|
||
(3) Label Feature Store Schema | ||
|
||
The label feature store supports variable number of columns. The column names | ||
can be provided in the configuration file. | ||
``` | ||
# +-----------------+---------------+-----------------------+ | ||
# | Column Name | Data Type | Description | | ||
# +-----------------+---------------+-----------------------+ | ||
# | lbl_id | int32 | input id | | ||
# | title | string | title text | | ||
# | content | string | content string | | ||
# | ... | string | content string | | ||
# +-----------------+---------------+-----------------------+ | ||
``` | ||
|
||
(4) Evaluation Schema | ||
``` | ||
# +-----------------+---------------+-----------------------+ | ||
# | Column Name | Data Type | Description | | ||
# +-----------------+---------------+-----------------------+ | ||
# | inp_id | int32 | input id | | ||
# | lbl_id | int32 | label_id | | ||
# +-----------------+---------------+-----------------------+ | ||
``` | ||
|
||
(5) Evaluation Input Feature Store Schema | ||
``` | ||
# +-----------------+---------------+-----------------------+ | ||
# | Column Name | Data Type | Description | | ||
# +-----------------+---------------+-----------------------+ | ||
# | inp_id | int32 | input id | | ||
# | keywords | string | keyword string | | ||
# +-----------------+---------------+-----------------------+ | ||
``` | ||
|
||
(6) Evaluation Label Feature Store Schema | ||
``` | ||
# +-----------------+---------------+-----------------------+ | ||
# | Column Name | Data Type | Description | | ||
# +-----------------+---------------+-----------------------+ | ||
# | lbl_id | int32 | input id | | ||
# | title | string | title text | | ||
# | content | string | content string | | ||
# | ... | string | content string | | ||
# +-----------------+---------------+-----------------------+ | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import os | ||
import random | ||
from collections import OrderedDict | ||
from typing import List, Tuple, Callable | ||
|
||
import numpy as np | ||
import pyarrow.parquet as pq | ||
from datasets import load_dataset | ||
|
||
import pecos | ||
|
||
|
||
class RankingDataUtils(pecos.BaseClass): | ||
""" | ||
Utility class for handling data related tasks | ||
""" | ||
|
||
@classmethod | ||
def remap_ordereddict(cls, od: OrderedDict, keymap_fn: Callable): | ||
""" | ||
Function to remap the keys of an ordered Dictionary | ||
Args: | ||
od: The ordered dictionary to remap | ||
keymap_fn: The function to map the keys | ||
""" | ||
new_od = OrderedDict() | ||
for k, v in od.items(): | ||
new_od[keymap_fn(k)] = v | ||
return new_od | ||
|
||
@classmethod | ||
def _format_sample( | ||
cls, | ||
inp_text: str, | ||
lbl_contents: List[str], | ||
inp_prefix: str = "...", | ||
passage_prefix: str = "...", | ||
content_sep=" ", | ||
) -> str: | ||
""" | ||
Function to convert the text fields into a formatted string | ||
that the model understands. | ||
Args: | ||
inp_text: The input text | ||
lbl_contents: The list of content fields | ||
inp_prefix: The input prefix | ||
passage_prefix: The passage prefix | ||
content_sep: The separator between the content fields | ||
Returns: The formatted string | ||
""" | ||
# Convention from rankllama is to replace hyphens in the title | ||
lbl_contents[0] = lbl_contents[0].replace("-", " ").strip() | ||
return f"{inp_prefix} {inp_text} {passage_prefix} {content_sep.join(lbl_contents)}".strip() | ||
|
||
@classmethod | ||
def _create_sample( | ||
cls, | ||
inp_id: int, | ||
ret_idxs: List[int], | ||
scores: List[float], | ||
table_stores, | ||
train_group_size: int, | ||
inp_prefix: str, | ||
passage_prefix: str, | ||
keyword_col_name: str, | ||
content_col_names: List[str], | ||
content_sep, | ||
) -> Tuple[List[str], List[float]]: | ||
""" | ||
Function to create a sample for training. | ||
Args: | ||
inp_id: The input id | ||
ret_idxs: The retrieved indices | ||
scores: Scores for the retrieved indices | ||
table_stores: Dictionary of table stores for input and label data | ||
train_group_size: The number of passages used to train for each query | ||
inp_prefix: The input prefix | ||
passage_prefix: The passage prefix | ||
keyword_col_name: The column name for the query text | ||
content_col_names: The column names for the content fields | ||
content_sep: The separator between the content fields | ||
Returns: A tuple of formatted samples and scores | ||
""" | ||
qid = inp_id | ||
pidxs = ret_idxs | ||
|
||
input_store = table_stores["input"] | ||
label_store = table_stores["label"] | ||
|
||
# get the values of the query | ||
query = input_store[qid][keyword_col_name] | ||
mean_score = np.mean(scores) | ||
|
||
# get idxs for positive items | ||
pos_idxs = [(x, pid) for x, pid in zip(scores, pidxs) if x > mean_score] | ||
neg_idxs = [(x, pid) for x, pid in zip(scores, pidxs) if x <= mean_score] | ||
random.shuffle(pos_idxs) | ||
random.shuffle(neg_idxs) | ||
|
||
num_positives = train_group_size // 2 | ||
|
||
all_selections = pos_idxs[:num_positives] | ||
num_positives = len(all_selections) | ||
num_negatives = train_group_size - num_positives | ||
all_selections.extend(neg_idxs[:num_negatives]) | ||
|
||
if len(all_selections) < train_group_size: | ||
all_selections.extend( | ||
random.choices(neg_idxs, k=train_group_size - len(all_selections)) | ||
) | ||
|
||
all_scores = [s for s, _ in all_selections] | ||
all_pids = [pid for _, pid in all_selections] | ||
|
||
# get the values for the retrieved items | ||
ret_info = [label_store[i] for i in all_pids] | ||
|
||
formated_pair = [] | ||
for info in ret_info: | ||
formated_pair.append( | ||
cls._format_sample( | ||
query, | ||
[info[c] for c in content_col_names], | ||
inp_prefix, | ||
passage_prefix, | ||
content_sep, | ||
) | ||
) | ||
return formated_pair, all_scores | ||
|
||
@classmethod | ||
def get_parquet_rows(cls, folder_path: str) -> int: | ||
""" | ||
Returns the count of rows in parquet files by reading the | ||
metadata | ||
Args: | ||
folder_path: The folder containing the parquet files | ||
Returns: The count of rows in the parquet files | ||
""" | ||
file_list = os.listdir(folder_path) | ||
file_list = [os.path.join(folder_path, x) for x in file_list] | ||
cumulative_rowcount = sum([pq.read_metadata(fp).num_rows for fp in file_list]) | ||
|
||
return cumulative_rowcount | ||
|
||
@classmethod | ||
def get_sorted_data_files(cls, filenames: List[str], idx_colname) -> List[str]: | ||
""" | ||
Returns the list of files sorted by the id in the first row of each file | ||
Args: | ||
filenames: The list of filenames | ||
idx_colname: The column name of the id | ||
Returns: The sorted list of filenames | ||
""" | ||
# Load the datasets in streaming format and read the first id | ||
fn_ordered = [] # this containes tuples with (idx, filename) | ||
for fn in filenames: | ||
tmp_ds = load_dataset("parquet", data_files=fn, streaming=True, split="train") | ||
row = next(iter(tmp_ds.take(1))) | ||
fn_ordered.append((row[idx_colname], fn)) | ||
del tmp_ds | ||
fn_ordered = sorted(fn_ordered, key=lambda x: x[0]) | ||
|
||
return [x[1] for x in fn_ordered] |
Oops, something went wrong.