From f1e7b18fa9b07c7d2d658648494d1bf746425d98 Mon Sep 17 00:00:00 2001 From: Wenshansilvia Date: Sat, 10 Feb 2024 21:56:17 +0800 Subject: [PATCH] add bge large reranker --- ....wenshandeMacBook-Air.local.33294.XMCccIZx | Bin 0 -> 53248 bytes gomate/applications/RerankerApp.py | 26 +++++++ .../{RewriteApp.py => RewriterApp.py} | 0 gomate/applications/__init__.py | 3 +- gomate/modules/__init__.py | 1 + gomate/modules/reranker/__init__.py | 1 + gomate/modules/reranker/base_reranker.py | 14 ++++ gomate/modules/reranker/bge_large_reranker.py | 66 ++++++++++++++++++ gomate/modules/rewriter/base_rewriter.py | 2 + requirements.txt | 4 ++ tests/units/test_reranker.py | 16 +++++ 11 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 .coverage.wenshandeMacBook-Air.local.33294.XMCccIZx create mode 100644 gomate/applications/RerankerApp.py rename gomate/applications/{RewriteApp.py => RewriterApp.py} (100%) create mode 100644 gomate/modules/reranker/__init__.py create mode 100644 gomate/modules/reranker/base_reranker.py create mode 100644 gomate/modules/reranker/bge_large_reranker.py create mode 100644 tests/units/test_reranker.py diff --git a/.coverage.wenshandeMacBook-Air.local.33294.XMCccIZx b/.coverage.wenshandeMacBook-Air.local.33294.XMCccIZx new file mode 100644 index 0000000000000000000000000000000000000000..92c9a68c3d1c1e870bc767dea5f32fd9323df685 GIT binary patch literal 53248 zcmeI4U2oe|7{~1-jhob|LDLjf)zq^KNVB!gq9GwRO{`-T3c z)>eR!*TEztz5(K+z-Qow7YQ!7+Z|V2al>=`k~m#6)}_+4>~A%$U(R{_e}3nA&dG@r zU%PtQ^f+tUPRH=rY3T(?mZkHINs^SOSAkyfqR>Vr-l6|;VtdqPURu$Q6|@f}HT#{U zeO9=lEvw%bUY_|`?c~0l`9U#qbvl6s1V8`;K;Zc%;9pgh;@q5k_e;;HwYldQb?$`g z!W-*rs~c-=@1tyTx~wf&$UiJihIP2gnFFTb3xn7NhEA=I|zW zxK-zFsAADG1fsvacQa?XtsKt8X>>g>vToqP>t}yo;|AOsJfmVIs4dcvrk(N z{2gk}7`bTCuvX79Z})h=@t4@3=2GB>WiK(|1VzA0Guh&6M`Vejl5e}W+w^`*H+r5O z?CaF{dX-k{y%(pI;=+P_|8@{40*_uZJ-5G+8w#5K;Sr%D&X0~8ktM@HBC@F6W_B}3 z8Bv>sjYH9#%$Z6I!*E2RF*_(+(sGooX&yTw53wI{gsPtNqQqKbKXC^YilETo6W4YeoN*p%98& za-VKAKGord+jHWMGcz9HLePF}g z5jI+PN`a<$a1q`71|?mbU$oKb7*-=1uf#AMTBW{UB&WDe*oja@R-&YIDMk0#jEzQT zMUNHRU+EGPVYOp4!gLpnlCZe7$saNojz$*^k9%f^Gf^ZE69m?l;j+5J=_HM%6i=R%{a%!eVLzk0G^lrsSa!6V8VZs~@C3c<*f)9Ib1QG!@6zw3D=nLz zD8?%tyU}ZNx8m?Ehn{@lRNmUTxTeQ@^o#BF88Tcx1jD*+TBfJ#Wb=Cdki^yumk)`} zjVqR(=yHoFB*jjTD>kwji4FIN!Y2wDJ~s~8*l>&MZTfk?K+(5;DNELiPi{Ry;OZ2a zDjnRE(e1W_2mR1m{Hff*lX5Ws7x(|^q5oJw00ck)1V8`;KmY_l00ck)1V8`;CYXS@|Bw0q z1UE2T0|5{K0T2KI5C8!X009sH0T2KIB7pn<2m=rR0T2KI5C8!X009sH0T2KI5SV-d zxc@)-Z43`V00ck)1V8`;KmY_l00ck)1V8}y{}BTq00JNY0w4eaAOHd&00JNY0w6H? z1aSX<^4l05f&d7B00@8p2!H?xfB*=900@8p?*AhOKmY_l00ck)1V8`;KmY_l00cl_ H@(KJ0onik6 literal 0 HcmV?d00001 diff --git a/gomate/applications/RerankerApp.py b/gomate/applications/RerankerApp.py new file mode 100644 index 0000000..c946ea0 --- /dev/null +++ b/gomate/applications/RerankerApp.py @@ -0,0 +1,26 @@ +from gomate.modules import bge_large_reranker + + +class RerankerApp(): + """重排模块,评估文档的相关性并重新排序。把最有可能提供准确、相关回答的文档排在前面。 + + 实现包括 + 1. bge-reranker-large。智源开源的Rerank模型。 + 2. ... + + """ + + def __init__(self, component_name=None): + """Init required reranker according to component name.""" + self.reranker_list = ['bge_large'] + assert component_name in self.reranker_list + if component_name == 'bge_large': + self.reranker = bge_large_reranker() + + def run(self, query, contexts): + """Run the required reranker""" + if query is None: + raise ValueError('missing query') + if contexts is None: + raise ValueError('missing contexts') + return self.reranker.run(query, contexts) diff --git a/gomate/applications/RewriteApp.py b/gomate/applications/RewriterApp.py similarity index 100% rename from gomate/applications/RewriteApp.py rename to gomate/applications/RewriterApp.py diff --git a/gomate/applications/__init__.py b/gomate/applications/__init__.py index f26751f..ad60ae1 100644 --- a/gomate/applications/__init__.py +++ b/gomate/applications/__init__.py @@ -1 +1,2 @@ -from .RewriteApp import RewriterApp \ No newline at end of file +from .RewriterApp import RewriterApp +from .RerankerApp import RerankerApp \ No newline at end of file diff --git a/gomate/modules/__init__.py b/gomate/modules/__init__.py index 364da00..2bcd343 100644 --- a/gomate/modules/__init__.py +++ b/gomate/modules/__init__.py @@ -1 +1,2 @@ +from .reranker import bge_large_reranker from .rewriter import HyDE_rewriter \ No newline at end of file diff --git a/gomate/modules/reranker/__init__.py b/gomate/modules/reranker/__init__.py new file mode 100644 index 0000000..583fb8c --- /dev/null +++ b/gomate/modules/reranker/__init__.py @@ -0,0 +1 @@ +from .bge_large_reranker import bge_large_reranker \ No newline at end of file diff --git a/gomate/modules/reranker/base_reranker.py b/gomate/modules/reranker/base_reranker.py new file mode 100644 index 0000000..66e49f5 --- /dev/null +++ b/gomate/modules/reranker/base_reranker.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod + + +class base_reranker(ABC): + """Define base reranker.""" + + @abstractmethod + def __init__(self, component_name=None): + """Init required reranker according to component name.""" + ... + + def run(self, query, contexts): + """Run the required reranker""" + ... diff --git a/gomate/modules/reranker/bge_large_reranker.py b/gomate/modules/reranker/bge_large_reranker.py new file mode 100644 index 0000000..c3a2b86 --- /dev/null +++ b/gomate/modules/reranker/bge_large_reranker.py @@ -0,0 +1,66 @@ +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from tqdm import tqdm +from typing import List +import numpy as np + + +class bge_large_reranker(): + """This is bge-reranker-large.""" + + def __init__(self, + model_name_or_path: str = 'BAAI/bge-reranker-large', + use_fp16: bool = False): + """Init the hyde reranker model""" + self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + elif torch.backends.mps.is_available(): + self.device = torch.device('mps') + else: + self.device = torch.device('cpu') + use_fp16 = False + if use_fp16: + self.model.half() + self.model = self.model.to(self.device) + self.model.eval() + + self.num_gpus = torch.cuda.device_count() + if self.num_gpus > 1: + print(f"----------using {self.num_gpus}*GPUs----------") + self.model = torch.nn.DataParallel(self.model) + + @torch.no_grad() + def run(self, query, contexts, batch_size: int = 256, + max_length: int = 512) -> List[float]: + """Get reranked contexts in runtime""" + + if self.num_gpus > 0: + batch_size = batch_size * self.num_gpus + + assert isinstance(query, str) + assert isinstance(contexts, list) + sentence_pairs = [[query, cxt] for cxt in contexts] + + all_scores = [] + for start_index in tqdm(range(0, len(sentence_pairs), batch_size), desc="Compute Reranking Scores", + disable=len(sentence_pairs) < 128): + sentences_batch = sentence_pairs[start_index:start_index + batch_size] + inputs = self.tokenizer( + sentences_batch, + padding=True, + truncation=True, + return_tensors='pt', + max_length=max_length, + ).to(self.device) + + scores = self.model(**inputs, return_dict=True).logits.view(-1, ).float() + all_scores.extend(scores.cpu().numpy().tolist()) + + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + probabilities = sigmoid(np.array(all_scores)) + print(probabilities) + return probabilities diff --git a/gomate/modules/rewriter/base_rewriter.py b/gomate/modules/rewriter/base_rewriter.py index ce2b88f..bc0a7bb 100644 --- a/gomate/modules/rewriter/base_rewriter.py +++ b/gomate/modules/rewriter/base_rewriter.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod + class base_rewriter(ABC): """Define base rewriter.""" + @abstractmethod def __init__(self, component_name=None): """Init required rewriter according to component name.""" diff --git a/requirements.txt b/requirements.txt index 1c0ade0..db5e972 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,7 @@ pydocstyle == 2.1 openai == 1.10.0 datasets == 2.16.1 langchain == 0.1.4 +transformers == 4.37.2 +torch == 2.2.0 +pandas == 2.0.0 +nltk == 3.8.1 diff --git a/tests/units/test_reranker.py b/tests/units/test_reranker.py new file mode 100644 index 0000000..05ae213 --- /dev/null +++ b/tests/units/test_reranker.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +import pytest +from gomate.applications import RerankerApp +# import os + +def test_reranker(): + component_name = 'bge_large' + model = RerankerApp(component_name = component_name) + query = "恐龙是怎么被命名的?" + contexts = ["[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。 [12]“我们的结果显示恐龙所具有的生长速率和新陈代谢速率,既不是冷血生物体也不是温血生物体所具有的特征。它们既不像哺乳动物或者鸟类,也不像爬行动物或者鱼类,而是介于现代冷血动物和温血动物之间。简言之,它们的生理机能在现代社会并不常见。”美国亚利桑那大学进化生物学家和生态学家布莱恩·恩奎斯特说。墨西哥生物学家表示,正是这种中等程度的新陈代谢使得恐龙可以长得比任何哺乳动物都要大。温血动物需要大量进食,因此它们频繁猎捕和咀嚼植物。“很难想象霸王龙大小的狮子能够吃饱以 存活下来。","[12]哺乳动物起源于爬行动物,它们的前身是“似哺乳类的爬行动物”,即兽孔目,早期则是“似爬行类的哺乳动物”,即哺乳型动物。 [12]中生代的爬行动物,大部分在中生代的末期灭绝了;一部分适应了变化的环境被保留下来,即现存的爬行动物(如龟鳖类、蛇类、鳄类等);还有一部分沿着不同的进化方向,进化成了现今的鸟类和哺乳类。 [12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。"] + probabilities = model.run(query, contexts) + assert probabilities is not None + +if __name__ == '__main__': + test_reranker() \ No newline at end of file